diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index fea153453ac..53cf6799099 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -242,12 +242,14 @@ message = """ * The `length` trait on `string` shapes. * The `length` trait on `map` shapes. +* The `range` trait on `integer` shapes. +* The `pattern` trait on `string` shapes. Upon receiving a request that violates the modeled constraints, the server SDK will reject it with a message indicating why. Unsupported (constraint trait, target shape) combinations will now fail at code generation time, whereas previously they were just ignored. This is a breaking change to raise awareness in service owners of their server SDKs behaving differently than what was modeled. To continue generating a server SDK with unsupported constraint traits, set `codegenConfig.ignoreUnsupportedConstraints` to `true` in your `smithy-build.json`. """ -references = ["smithy-rs#1199", "smithy-rs#1342", "smithy-rs#1401"] +references = ["smithy-rs#1199", "smithy-rs#1342", "smithy-rs#1401", "smithy-rs#2005", "smithy-rs#1998"] meta = { "breaking" = true, "tada" = true, "bug" = false, "target" = "server" } author = "david-perez" @@ -478,3 +480,9 @@ x-amzn-errortype: com.example.service#InvalidRequestException references = ["smithy-rs#1982"] meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "server" } author = "david-perez" + +[[smithy-rs]] +message = "Make generated enum `values()` functions callable in const contexts." +references = ["smithy-rs#2011"] +meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "all" } +author = "lsr0" diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt index 3a9662d8691..2ec1b5e290e 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt @@ -107,7 +107,7 @@ class AwsFluentClientDecorator : RustCodegenDecorator` for the nested item diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt index 262afbd616b..fa4e7ae29ac 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt @@ -13,10 +13,8 @@ import software.amazon.smithy.model.traits.IdempotencyTokenTrait import software.amazon.smithy.model.traits.PaginatedTrait import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerics import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustType -import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.rust @@ -78,11 +76,7 @@ class PaginatorGenerator private constructor( private val idx = PaginatedIndex.of(model) private val paginationInfo = idx.getPaginationInfo(service, operation).orNull() ?: PANIC("failed to load pagination info") - private val module = RustModule( - "paginator", - RustMetadata(visibility = Visibility.PUBLIC), - documentation = "Paginators for the service", - ) + private val module = RustModule.public("paginator", "Paginators for the service") private val inputType = symbolProvider.toSymbol(operation.inputShape(model)) private val outputShape = operation.outputShape(model) @@ -99,7 +93,12 @@ class PaginatorGenerator private constructor( "generics" to generics.decl, "bounds" to generics.bounds, "page_size_setter" to pageSizeSetter(), - "send_bounds" to generics.sendBounds(symbolProvider.toSymbol(operation), outputType, errorType, retryClassifier), + "send_bounds" to generics.sendBounds( + symbolProvider.toSymbol(operation), + outputType, + errorType, + retryClassifier, + ), // Operation Types "operation" to symbolProvider.toSymbol(operation), @@ -288,7 +287,8 @@ class PaginatorGenerator private constructor( private fun pageSizeSetter() = writable { paginationInfo.pageSizeMember.orNull()?.also { val memberName = symbolProvider.toMemberName(it) - val pageSizeT = symbolProvider.toSymbol(it).rustType().stripOuter().render(true) + val pageSizeT = + symbolProvider.toSymbol(it).rustType().stripOuter().render(true) rust( """ /// Set the page size diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGenerator.kt index 1b1a9872549..74dcb5a3dc7 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGenerator.kt @@ -11,8 +11,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustGenerics import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Visibility -import software.amazon.smithy.rust.codegen.core.rustlang.docs -import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RustCrate @@ -26,20 +24,16 @@ class CustomizableOperationGenerator( private val generics: FluentClientGenerics, private val includeFluentClient: Boolean, ) { + companion object { - const val CUSTOMIZE_MODULE = "crate::operation::customize" + val CustomizeModule = RustModule.public("customize", "Operation customization and supporting types", parent = RustModule.operation(Visibility.PUBLIC)) } private val smithyHttp = CargoDependency.smithyHttp(runtimeConfig).toType() private val smithyTypes = CargoDependency.smithyTypes(runtimeConfig).toType() fun render(crate: RustCrate) { - crate.withModule(RustModule.operation(Visibility.PUBLIC)) { - docs("Operation customization and supporting types") - rust("pub mod customize;") - } - - crate.withNonRootModule(CUSTOMIZE_MODULE) { + crate.withModule(CustomizeModule) { rustTemplate( """ pub use #{Operation}; diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt index b79537dd010..03ce4859ea3 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt @@ -21,6 +21,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.asArgumentType import software.amazon.smithy.rust.codegen.core.rustlang.asOptional import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape @@ -220,7 +221,7 @@ class FluentClientGenerator( ) } } - writer.withModule(RustModule.public("fluent_builders")) { + writer.withInlineModule(RustModule.new("fluent_builders", visibility = Visibility.PUBLIC, inline = true)) { docs( """ Utilities to ergonomically construct a request to the service. diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt index 830308060f4..82ecfe6967c 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -100,7 +100,7 @@ class ProtocolTestGenerator( Attribute.Custom("allow(unreachable_code, unused_variables)"), ), ) - writer.withModule(RustModule(testModuleName, moduleMeta)) { + writer.withInlineModule(RustModule.LeafModule(testModuleName, moduleMeta, inline = true)) { renderAllTestCases(allTests) } } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/SmithyTypesPubUseGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/SmithyTypesPubUseGeneratorTest.kt index d402fb8acfd..6a7ca5be25f 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/SmithyTypesPubUseGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/SmithyTypesPubUseGeneratorTest.kt @@ -13,7 +13,6 @@ import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel class SmithyTypesPubUseGeneratorTest { - private fun emptyModel() = modelWithMember() private fun modelWithMember( inputMember: String = "", outputMember: String = "", diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/CodegenVisitorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/CodegenVisitorTest.kt index eccb9058d5b..b86e3b35c77 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/CodegenVisitorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/CodegenVisitorTest.kt @@ -15,8 +15,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.customize.RequiredCusto import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientDecorator import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.generatePluginContext -import kotlin.io.path.createDirectory -import kotlin.io.path.writeText class CodegenVisitorTest { @Test @@ -48,8 +46,6 @@ class CodegenVisitorTest { } """.asSmithyModel(smithyVersion = "2.0") val (ctx, testDir) = generatePluginContext(model) - testDir.resolve("src").createDirectory() - testDir.resolve("src/main.rs").writeText("fn main() {}") val codegenDecorator = CombinedCodegenDecorator.fromClasspath( ctx, diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/SymbolVisitorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/SymbolVisitorTest.kt index e7e87a89a68..bc26847bef5 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/SymbolVisitorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/SymbolVisitorTest.kt @@ -29,9 +29,9 @@ import software.amazon.smithy.model.traits.SparseTrait import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.render -import software.amazon.smithy.rust.codegen.core.smithy.Errors -import software.amazon.smithy.rust.codegen.core.smithy.Models -import software.amazon.smithy.rust.codegen.core.smithy.Operations +import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.smithy.OperationsModule import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel @@ -57,7 +57,7 @@ class SymbolVisitorTest { val provider: SymbolProvider = testSymbolProvider(model) val sym = provider.toSymbol(struct) sym.rustType().render(false) shouldBe "MyStruct" - sym.definitionFile shouldContain Models.filename + sym.definitionFile shouldContain ModelsModule.definitionFile() sym.namespace shouldBe "crate::model" } @@ -77,7 +77,7 @@ class SymbolVisitorTest { val provider: SymbolProvider = testSymbolProvider(model) val sym = provider.toSymbol(struct) sym.rustType().render(false) shouldBe "TerribleError" - sym.definitionFile shouldContain Errors.filename + sym.definitionFile shouldContain ErrorsModule.definitionFile() } @Test @@ -101,7 +101,7 @@ class SymbolVisitorTest { val provider: SymbolProvider = testSymbolProvider(model) val sym = provider.toSymbol(shape) sym.rustType().render(false) shouldBe "StandardUnit" - sym.definitionFile shouldContain Models.filename + sym.definitionFile shouldContain ModelsModule.definitionFile() sym.namespace shouldBe "crate::model" } @@ -260,7 +260,7 @@ class SymbolVisitorTest { } """.asSmithyModel() val symbol = testSymbolProvider(model).toSymbol(model.expectShape(ShapeId.from("smithy.example#PutObject"))) - symbol.definitionFile shouldBe("src/${Operations.filename}") + symbol.definitionFile shouldBe(OperationsModule.definitionFile()) symbol.name shouldBe "PutObject" } } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/EndpointTraitBindingsTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/EndpointTraitBindingsTest.kt index 41aa7eacb96..5448636ac11 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/EndpointTraitBindingsTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/EndpointTraitBindingsTest.kt @@ -12,7 +12,6 @@ import software.amazon.smithy.model.traits.EndpointTrait import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.rustlang.RustModule -import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock @@ -59,7 +58,7 @@ internal class EndpointTraitBindingsTest { operationShape.expectTrait(EndpointTrait::class.java), ) val project = TestWorkspace.testProject() - project.withModule(RustModule.default("test", visibility = Visibility.PRIVATE)) { + project.withModule(RustModule.private("test")) { rust( """ struct GetStatusInput { diff --git a/codegen-core/common-test-models/constraints.smithy b/codegen-core/common-test-models/constraints.smithy index ddfe920cdde..ccef90c1477 100644 --- a/codegen-core/common-test-models/constraints.smithy +++ b/codegen-core/common-test-models/constraints.smithy @@ -19,11 +19,19 @@ service ConstraintsService { // combination. QueryParamsTargetingLengthMapOperation, QueryParamsTargetingMapOfLengthStringOperation, - QueryParamsTargetingMapOfEnumStringOperation, QueryParamsTargetingMapOfListOfLengthStringOperation, QueryParamsTargetingMapOfSetOfLengthStringOperation, QueryParamsTargetingMapOfListOfEnumStringOperation, + + QueryParamsTargetingMapOfPatternStringOperation, + QueryParamsTargetingMapOfListOfPatternStringOperation, + QueryParamsTargetingMapOfLengthPatternStringOperation, + QueryParamsTargetingMapOfListOfLengthPatternStringOperation, + HttpPrefixHeadersTargetingLengthMapOperation, + + QueryParamsTargetingMapOfEnumStringOperation, + QueryParamsTargetingMapOfListOfEnumStringOperation, // TODO(https://github.com/awslabs/smithy-rs/issues/1431) // HttpPrefixHeadersTargetingMapOfEnumStringOperation, @@ -41,7 +49,7 @@ operation ConstrainedShapesOperation { errors: [ValidationException] } -@http(uri: "/constrained-http-bound-shapes-operation/{lengthStringLabel}/{enumStringLabel}", method: "POST") +@http(uri: "/constrained-http-bound-shapes-operation/{rangeIntegerLabel}/{lengthStringLabel}/{enumStringLabel}", method: "POST") operation ConstrainedHttpBoundShapesOperation { input: ConstrainedHttpBoundShapesOperationInputOutput, output: ConstrainedHttpBoundShapesOperationInputOutput, @@ -97,6 +105,34 @@ operation QueryParamsTargetingMapOfListOfEnumStringOperation { errors: [ValidationException] } +@http(uri: "/query-params-targeting-map-of-pattern-string-operation", method: "POST") +operation QueryParamsTargetingMapOfPatternStringOperation { + input: QueryParamsTargetingMapOfPatternStringOperationInputOutput, + output: QueryParamsTargetingMapOfPatternStringOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/query-params-targeting-map-of-list-of-pattern-string-operation", method: "POST") +operation QueryParamsTargetingMapOfListOfPatternStringOperation { + input: QueryParamsTargetingMapOfListOfPatternStringOperationInputOutput, + output: QueryParamsTargetingMapOfListOfPatternStringOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/query-params-targeting-map-of-length-pattern-string", method: "POST") +operation QueryParamsTargetingMapOfLengthPatternStringOperation { + input: QueryParamsTargetingMapOfLengthPatternStringOperationInputOutput, + output: QueryParamsTargetingMapOfLengthPatternStringOperationInputOutput, + errors: [ValidationException], +} + +@http(uri: "/query-params-targeting-map-of-list-of-length-pattern-string-operation", method: "POST") +operation QueryParamsTargetingMapOfListOfLengthPatternStringOperation { + input: QueryParamsTargetingMapOfListOfLengthPatternStringOperationInputOutput, + output: QueryParamsTargetingMapOfListOfLengthPatternStringOperationInputOutput, + errors: [ValidationException] +} + @http(uri: "/http-prefix-headers-targeting-length-map-operation", method: "POST") operation HttpPrefixHeadersTargetingLengthMapOperation { input: HttpPrefixHeadersTargetingLengthMapOperationInputOutput, @@ -139,18 +175,24 @@ structure ConstrainedHttpBoundShapesOperationInputOutput { @httpLabel lengthStringLabel: LengthString, + @required + @httpLabel + rangeIntegerLabel: RangeInteger, + @required @httpLabel enumStringLabel: EnumString, - // TODO(https://github.com/awslabs/smithy-rs/issues/1394) `@required` not working - // @required - @httpPrefixHeaders("X-Prefix-Headers-") + @required + @httpPrefixHeaders("X-Length-String-Prefix-Headers-") lengthStringHeaderMap: MapOfLengthString, @httpHeader("X-Length") lengthStringHeader: LengthString, + @httpHeader("X-Range-Integer") + rangeIntegerHeader: RangeInteger, + // @httpHeader("X-Length-MediaType") // lengthStringHeaderWithMediaType: MediaTypeLengthString, @@ -162,6 +204,14 @@ structure ConstrainedHttpBoundShapesOperationInputOutput { @httpHeader("X-Length-List") lengthStringListHeader: ListOfLengthString, + // TODO(https://github.com/awslabs/smithy-rs/issues/1401): a `set` shape is + // just a `list` shape with `uniqueItems`, which hasn't been implemented yet. + // @httpHeader("X-Range-Integer-Set") + // rangeIntegerSetHeader: SetOfRangeInteger, + + @httpHeader("X-Range-Integer-List") + rangeIntegerListHeader: ListOfRangeInteger, + // TODO(https://github.com/awslabs/smithy-rs/issues/1431) // @httpHeader("X-Enum") //enumStringHeader: EnumString, @@ -172,6 +222,9 @@ structure ConstrainedHttpBoundShapesOperationInputOutput { @httpQuery("lengthString") lengthStringQuery: LengthString, + @httpQuery("rangeInteger") + rangeIntegerQuery: RangeInteger, + @httpQuery("enumString") enumStringQuery: EnumString, @@ -183,10 +236,38 @@ structure ConstrainedHttpBoundShapesOperationInputOutput { // @httpQuery("lengthStringSet") // lengthStringSetQuery: SetOfLengthString, + @httpQuery("rangeIntegerList") + rangeIntegerListQuery: ListOfRangeInteger, + + // TODO(https://github.com/awslabs/smithy-rs/issues/1401): a `set` shape is + // just a `list` shape with `uniqueItems`, which hasn't been implemented yet. + // @httpQuery("rangeIntegerSet") + // rangeIntegerSetQuery: SetOfRangeInteger, + @httpQuery("enumStringList") enumStringListQuery: ListOfEnumString, } +structure QueryParamsTargetingMapOfPatternStringOperationInputOutput { + @httpQueryParams + mapOfPatternString: MapOfPatternString +} + +structure QueryParamsTargetingMapOfListOfPatternStringOperationInputOutput { + @httpQueryParams + mapOfListOfPatternString: MapOfListOfPatternString +} + +structure QueryParamsTargetingMapOfLengthPatternStringOperationInputOutput { + @httpQueryParams + mapOfLengthPatternString: MapOfLengthPatternString, +} + +structure QueryParamsTargetingMapOfListOfLengthPatternStringOperationInputOutput { + @httpQueryParams + mapOfLengthPatternString: MapOfListOfLengthPatternString, +} + structure HttpPrefixHeadersTargetingLengthMapOperationInputOutput { @httpPrefixHeaders("X-Prefix-Headers-LengthMap-") lengthMap: ConBMap, @@ -278,6 +359,11 @@ structure ConA { maxLengthString: MaxLengthString, fixedLengthString: FixedLengthString, + rangeInteger: RangeInteger, + minRangeInteger: MinRangeInteger, + maxRangeInteger: MaxRangeInteger, + fixedValueInteger: FixedValueInteger, + conBList: ConBList, conBList2: ConBList2, @@ -298,7 +384,27 @@ structure ConA { // setOfLengthString: SetOfLengthString, mapOfLengthString: MapOfLengthString, + listOfRangeInteger: ListOfRangeInteger, + // TODO(https://github.com/awslabs/smithy-rs/issues/1401): a `set` shape is + // just a `list` shape with `uniqueItems`, which hasn't been implemented yet. + // setOfRangeInteger: SetOfRangeInteger, + mapOfRangeInteger: MapOfRangeInteger, + nonStreamingBlob: NonStreamingBlob + + patternString: PatternString, + mapOfPatternString: MapOfPatternString, + listOfPatternString: ListOfPatternString, + // TODO(https://github.com/awslabs/smithy-rs/issues/1401): a `set` shape is + // just a `list` shape with `uniqueItems`, which hasn't been implemented yet. + // setOfPatternString: SetOfPatternString, + + lengthLengthPatternString: LengthPatternString, + mapOfLengthPatternString: MapOfLengthPatternString, + listOfLengthPatternString: ListOfLengthPatternString + // TODO(https://github.com/awslabs/smithy-rs/issues/1401): a `set` shape is + // just a `list` shape with `uniqueItems`, which hasn't been implemented yet. + // setOfLengthPatternString: SetOfLengthPatternString, } map MapOfLengthString { @@ -306,6 +412,11 @@ map MapOfLengthString { value: LengthString, } +map MapOfRangeInteger { + key: String, + value: RangeInteger, +} + map MapOfEnumString { key: EnumString, value: EnumString, @@ -321,6 +432,16 @@ map MapOfListOfEnumString { value: ListOfEnumString, } +map MapOfListOfPatternString { + key: PatternString, + value: ListOfPatternString +} + +map MapOfListOfLengthPatternString { + key: LengthPatternString, + value: ListOfLengthPatternString +} + map MapOfSetOfLengthString { key: LengthString, // TODO(https://github.com/awslabs/smithy-rs/issues/1401): a `set` shape is @@ -329,6 +450,13 @@ map MapOfSetOfLengthString { value: ListOfLengthString } +// TODO(https://github.com/awslabs/smithy-rs/issues/1401): a `set` shape is +// just a `list` shape with `uniqueItems`, which hasn't been implemented yet. +// map MapOfSetOfRangeInteger { +// key: LengthString, +// value: SetOfRangeInteger, +// } + @length(min: 2, max: 8) list LengthListOfLengthString { member: LengthString @@ -340,16 +468,35 @@ string LengthString @length(min: 2) string MinLengthString -@length(min: 69) +@length(max: 69) string MaxLengthString @length(min: 69, max: 69) string FixedLengthString +@pattern("[a-d]{5}") +string PatternString + +@pattern("[a-f0-5]*") +@length(min: 5, max: 10) +string LengthPatternString + @mediaType("video/quicktime") @length(min: 1, max: 69) string MediaTypeLengthString +@range(min: -0, max: 69) +integer RangeInteger + +@range(min: -10) +integer MinRangeInteger + +@range(max: 69) +integer MaxRangeInteger + +@range(min: 69, max: 69) +integer FixedValueInteger + /// A union with constrained members. union ConstrainedUnion { enumString: EnumString, @@ -383,14 +530,40 @@ set SetOfLengthString { member: LengthString } +set SetOfPatternString { + member: PatternString +} + +set SetOfLengthPatternString { + member: LengthPatternString +} + list ListOfLengthString { member: LengthString } +// TODO(https://github.com/awslabs/smithy-rs/issues/1401): a `set` shape is +// just a `list` shape with `uniqueItems`, which hasn't been implemented yet. +// set SetOfRangeInteger { +// member: RangeInteger +// } + +list ListOfRangeInteger { + member: RangeInteger +} + list ListOfEnumString { member: EnumString } +list ListOfPatternString { + member: PatternString +} + +list ListOfLengthPatternString { + member: LengthPatternString +} + structure ConB { @required nice: String, @@ -443,6 +616,16 @@ list NestedList { // member: String // } +map MapOfPatternString { + key: PatternString, + value: PatternString, +} + +map MapOfLengthPatternString { + key: LengthPatternString, + value: LengthPatternString, +} + @length(min: 1, max: 69) map ConBMap { key: String, diff --git a/codegen-core/common-test-models/malformed-range-extras.smithy b/codegen-core/common-test-models/malformed-range-extras.smithy new file mode 100644 index 00000000000..8fd9d93c112 --- /dev/null +++ b/codegen-core/common-test-models/malformed-range-extras.smithy @@ -0,0 +1,662 @@ +$version: "2.0" + +namespace aws.protocoltests.extras.restjson.validation + +use aws.api#service +use aws.protocols#restJson1 +use smithy.test#httpMalformedRequestTests +use smithy.framework#ValidationException + +/// A REST JSON service that sends JSON requests and responses with validation applied +@service(sdkId: "Rest Json Validation Protocol") +@restJson1 +service MalformedRangeValidation { + version: "2022-11-23", + operations: [ + MalformedRange, + MalformedRangeOverride, + ] +} + +@suppress(["UnstableTrait"]) +@http(uri: "/MalformedRange", method: "POST") +operation MalformedRange { + input: MalformedRangeInput, + errors: [ValidationException] +} + +@suppress(["UnstableTrait"]) +@http(uri: "/MalformedRangeOverride", method: "POST") +operation MalformedRangeOverride { + input: MalformedRangeOverrideInput, + errors: [ValidationException] +} + +apply MalformedRange @httpMalformedRequestTests([ + { + id: "RestJsonMalformedRangeShort", + documentation: """ + When a short member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "short" : $value:L }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value $value:L at '/short' failed to satisfy constraint: Member must be between 2 and 8, inclusive", + "fieldList" : [{"message": "Value $value:L at '/short' failed to satisfy constraint: Member must be between 2 and 8, inclusive", "path": "/short"}]}""" + } + } + }, + testParameters: { + value: ["1", "9"] + } + }, + { + id: "RestJsonMalformedRangeMinShort", + documentation: """ + When a short member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "minShort" : 1 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 1 at '/minShort' failed to satisfy constraint: Member must be greater than or equal to 2", + "fieldList" : [{"message": "Value 1 at '/minShort' failed to satisfy constraint: Member must be greater than or equal to 2", "path": "/minShort"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeMaxShort", + documentation: """ + When a short member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "maxShort" : 9 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 9 at '/maxShort' failed to satisfy constraint: Member must be less than or equal to 8", + "fieldList" : [{"message": "Value 9 at '/maxShort' failed to satisfy constraint: Member must be less than or equal to 8", "path": "/maxShort"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeInteger", + documentation: """ + When a integer member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "integer" : $value:L }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value $value:L at '/integer' failed to satisfy constraint: Member must be between 2 and 8, inclusive", + "fieldList" : [{"message": "Value $value:L at '/integer' failed to satisfy constraint: Member must be between 2 and 8, inclusive", "path": "/integer"}]}""" + } + } + }, + testParameters: { + value: ["1", "9"] + } + }, + { + id: "RestJsonMalformedRangeMinInteger", + documentation: """ + When a integer member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "minInteger" : 1 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 1 at '/minInteger' failed to satisfy constraint: Member must be greater than or equal to 2", + "fieldList" : [{"message": "Value 1 at '/minInteger' failed to satisfy constraint: Member must be greater than or equal to 2", "path": "/minInteger"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeMaxInteger", + documentation: """ + When a integer member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "maxInteger" : 9 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 9 at '/maxInteger' failed to satisfy constraint: Member must be less than or equal to 8", + "fieldList" : [{"message": "Value 9 at '/maxInteger' failed to satisfy constraint: Member must be less than or equal to 8", "path": "/maxInteger"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeLong", + documentation: """ + When a long member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "long" : $value:L }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value $value:L at '/long' failed to satisfy constraint: Member must be between 2 and 8, inclusive", + "fieldList" : [{"message": "Value $value:L at '/long' failed to satisfy constraint: Member must be between 2 and 8, inclusive", "path": "/long"}]}""" + } + } + }, + testParameters: { + value: ["1", "9"] + } + }, + { + id: "RestJsonMalformedRangeMinLong", + documentation: """ + When a long member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "minLong" : 1 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 1 at '/minLong' failed to satisfy constraint: Member must be greater than or equal to 2", + "fieldList" : [{"message": "Value 1 at '/minLong' failed to satisfy constraint: Member must be greater than or equal to 2", "path": "/minLong"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeMaxLong", + documentation: """ + When a long member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "maxLong" : 9 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 9 at '/maxLong' failed to satisfy constraint: Member must be less than or equal to 8", + "fieldList" : [{"message": "Value 9 at '/maxLong' failed to satisfy constraint: Member must be less than or equal to 8", "path": "/maxLong"}]}""" + } + } + } + }, +]) + +// now repeat the above tests, but for the more specific constraints applied to the input member +apply MalformedRangeOverride @httpMalformedRequestTests([ + { + id: "RestJsonMalformedRangeShortOverride", + documentation: """ + When a short member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "short" : $value:L }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value $value:L at '/short' failed to satisfy constraint: Member must be between 4 and 6, inclusive", + "fieldList" : [{"message": "Value $value:L at '/short' failed to satisfy constraint: Member must be between 4 and 6, inclusive", "path": "/short"}]}""" + } + } + }, + testParameters: { + value: ["3", "7"] + } + }, + { + id: "RestJsonMalformedRangeMinShortOverride", + documentation: """ + When a short member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "minShort" : 3 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 3 at '/minShort' failed to satisfy constraint: Member must be greater than or equal to 4", + "fieldList" : [{"message": "Value 3 at '/minShort' failed to satisfy constraint: Member must be greater than or equal to 4", "path": "/minShort"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeMaxShortOverride", + documentation: """ + When a short member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "maxShort" : 7 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 7 at '/maxShort' failed to satisfy constraint: Member must be less than or equal to 6", + "fieldList" : [{"message": "Value 7 at '/maxShort' failed to satisfy constraint: Member must be less than or equal to 6", "path": "/maxShort"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeIntegerOverride", + documentation: """ + When a integer member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "integer" : $value:L }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value $value:L at '/integer' failed to satisfy constraint: Member must be between 4 and 6, inclusive", + "fieldList" : [{"message": "Value $value:L at '/integer' failed to satisfy constraint: Member must be between 4 and 6, inclusive", "path": "/integer"}]}""" + } + } + }, + testParameters: { + value: ["3", "7"] + } + }, + { + id: "RestJsonMalformedRangeMinIntegerOverride", + documentation: """ + When a integer member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "minInteger" : 3 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 3 at '/minInteger' failed to satisfy constraint: Member must be greater than or equal to 4", + "fieldList" : [{"message": "Value 3 at '/minInteger' failed to satisfy constraint: Member must be greater than or equal to 4", "path": "/minInteger"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeMaxIntegerOverride", + documentation: """ + When a integer member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "maxInteger" : 7 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 7 at '/maxInteger' failed to satisfy constraint: Member must be less than or equal to 6", + "fieldList" : [{"message": "Value 7 at '/maxInteger' failed to satisfy constraint: Member must be less than or equal to 6", "path": "/maxInteger"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeLongOverride", + documentation: """ + When a long member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "long" : $value:L }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value $value:L at '/long' failed to satisfy constraint: Member must be between 4 and 6, inclusive", + "fieldList" : [{"message": "Value $value:L at '/long' failed to satisfy constraint: Member must be between 4 and 6, inclusive", "path": "/long"}]}""" + } + } + }, + testParameters: { + value: ["3", "7"] + } + }, + { + id: "RestJsonMalformedRangeMinLongOverride", + documentation: """ + When a long member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "minLong" : 3 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 3 at '/minLong' failed to satisfy constraint: Member must be greater than or equal to 4", + "fieldList" : [{"message": "Value 3 at '/minLong' failed to satisfy constraint: Member must be greater than or equal to 4", "path": "/minLong"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeMaxLongOverride", + documentation: """ + When a long member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "maxLong" : 7 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 7 at '/maxLong' failed to satisfy constraint: Member must be less than or equal to 6", + "fieldList" : [{"message": "Value 7 at '/maxLong' failed to satisfy constraint: Member must be less than or equal to 6", "path": "/maxLong"}]}""" + } + } + } + }, +]) + +structure MalformedRangeInput { + short: RangeShort, + minShort: MinShort, + maxShort: MaxShort, + + integer: RangeInteger, + minInteger: MinInteger, + maxInteger: MaxInteger, + + long: RangeLong, + minLong: MinLong, + maxLong: MaxLong, +} + +structure MalformedRangeOverrideInput { + @range(min: 4, max: 6) + short: RangeShort, + @range(min: 4) + minShort: MinShort, + @range(max: 6) + maxShort: MaxShort, + + @range(min: 4, max: 6) + integer: RangeInteger, + @range(min: 4) + minInteger: MinInteger, + @range(max: 6) + maxInteger: MaxInteger, + + @range(min: 4, max: 6) + long: RangeLong, + @range(min: 4) + minLong: MinLong, + @range(max: 6) + maxLong: MaxLong, +} + +@range(min: 2, max: 8) +short RangeShort + +@range(min: 2) +short MinShort + +@range(max: 8) +short MaxShort + +@range(min: 2, max: 8) +integer RangeInteger + +@range(min: 2) +integer MinInteger + +@range(max: 8) +integer MaxInteger + +@range(min: 2, max: 8) +long RangeLong + +@range(min: 2) +long MinLong + +@range(max: 8) +long MaxLong diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt index 0dddd335d39..3b17bb2855d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt @@ -7,6 +7,7 @@ package software.amazon.smithy.rust.codegen.core.rustlang import software.amazon.smithy.codegen.core.SymbolDependency import software.amazon.smithy.codegen.core.SymbolDependencyContainer +import software.amazon.smithy.rust.codegen.core.smithy.ConstrainedModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.util.dq @@ -67,57 +68,52 @@ class InlineDependency( return extraDependencies } - fun key() = "${module.name}::$name" + fun key() = "${module.fullyQualifiedPath()}::$name" companion object { fun forRustFile( - name: String, - baseDir: String, - vararg additionalDependencies: RustDependency, - ): InlineDependency = forRustFile(name, baseDir, visibility = Visibility.PRIVATE, *additionalDependencies) - - fun forRustFile( - name: String, - baseDir: String, - visibility: Visibility, + module: RustModule.LeafModule, + resourcePath: String, vararg additionalDependencies: RustDependency, ): InlineDependency { - val module = RustModule.default(name, visibility) - val filename = if (name.endsWith(".rs")) { name } else { "$name.rs" } // The inline crate is loaded as a dependency on the runtime classpath - val rustFile = this::class.java.getResource("/$baseDir/src/$filename") - check(rustFile != null) { "Rust file /$baseDir/src/$filename was missing from the resource bundle!" } - return InlineDependency(name, module, additionalDependencies.toList()) { + val rustFile = this::class.java.getResource(resourcePath) + check(rustFile != null) { "Rust file $resourcePath was missing from the resource bundle!" } + return InlineDependency(module.name, module, additionalDependencies.toList()) { raw(rustFile.readText()) } } - fun forRustFile(name: String, vararg additionalDependencies: RustDependency) = - forRustFile(name, "inlineable", *additionalDependencies) - - fun eventStream(runtimeConfig: RuntimeConfig) = - forRustFile("event_stream", CargoDependency.smithyEventStream(runtimeConfig)) + private fun forInlineableRustFile(name: String, vararg additionalDependencies: RustDependency) = + forRustFile(RustModule.private(name), "/inlineable/src/$name.rs", *additionalDependencies) fun jsonErrors(runtimeConfig: RuntimeConfig) = - forRustFile("json_errors", CargoDependency.Http, CargoDependency.smithyTypes(runtimeConfig)) + forInlineableRustFile( + "json_errors", + CargoDependency.smithyJson(runtimeConfig), + CargoDependency.Bytes, + CargoDependency.Http, + ) fun idempotencyToken() = - forRustFile("idempotency_token", CargoDependency.FastRand) + forInlineableRustFile("idempotency_token", CargoDependency.FastRand) fun ec2QueryErrors(runtimeConfig: RuntimeConfig): InlineDependency = - forRustFile("ec2_query_errors", CargoDependency.smithyXml(runtimeConfig)) + forInlineableRustFile("ec2_query_errors", CargoDependency.smithyXml(runtimeConfig)) fun wrappedXmlErrors(runtimeConfig: RuntimeConfig): InlineDependency = - forRustFile("rest_xml_wrapped_errors", CargoDependency.smithyXml(runtimeConfig)) + forInlineableRustFile("rest_xml_wrapped_errors", CargoDependency.smithyXml(runtimeConfig)) fun unwrappedXmlErrors(runtimeConfig: RuntimeConfig): InlineDependency = - forRustFile("rest_xml_unwrapped_errors", CargoDependency.smithyXml(runtimeConfig)) + forInlineableRustFile("rest_xml_unwrapped_errors", CargoDependency.smithyXml(runtimeConfig)) fun constrained(): InlineDependency = - forRustFile("constrained") + InlineDependency.forRustFile(ConstrainedModule, "/inlineable/src/constrained.rs") } } +fun InlineDependency.asType() = RuntimeType(name = null, dependency = this, namespace = module.fullyQualifiedPath()) + data class Feature(val name: String, val default: Boolean, val deps: List) /** @@ -221,7 +217,7 @@ data class CargoDependency( val Smol: CargoDependency = CargoDependency("smol", CratesIo("1.2.0"), DependencyScope.Dev) val TempFile: CargoDependency = CargoDependency("tempfile", CratesIo("3.2.0"), DependencyScope.Dev) val Tokio: CargoDependency = - CargoDependency("tokio", CratesIo("1.8.4"), DependencyScope.Dev, features = setOf("macros", "test-util")) + CargoDependency("tokio", CratesIo("1.8.4"), DependencyScope.Dev, features = setOf("macros", "test-util", "rt-multi-thread")) val TracingSubscriber: CargoDependency = CargoDependency( "tracing-subscriber", CratesIo("0.3.15"), diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustModule.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustModule.kt index aa7a0fb8d33..b9ecd1b5104 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustModule.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustModule.kt @@ -5,23 +5,72 @@ package software.amazon.smithy.rust.codegen.core.rustlang -data class RustModule(val name: String, val rustMetadata: RustMetadata, val documentation: String? = null) { - fun render(writer: RustWriter) { - documentation?.let { docs -> writer.docs(docs) } - rustMetadata.render(writer) - writer.write("mod $name;") +/** + * RustModule system. + * + * RustModules are idempotent, BUT, you must take care to always use the same module (matching docs, visibility, etc.): + * - There is no guarantee _which_ module will be rendered. + */ +sealed class RustModule { + + /** lib.rs */ + object LibRs : RustModule() + + /** + * LeafModule + * + * A LeafModule is _all_ modules that are not `lib.rs`. To create a nested leaf module, set `parent` to a module + * _other_ than `LibRs`. + * + * To avoid infinite loops, avoid setting parent to itself ;-) + */ + data class LeafModule( + val name: String, + val rustMetadata: RustMetadata, + val documentation: String? = null, + val parent: RustModule = LibRs, + val inline: Boolean = false, + ) : RustModule() { + init { + check(!name.contains("::")) { + "Module names CANNOT contain `::`—modules must be nested with parent (name was: `$name`)" + } + check(name != "") { + "Module name cannot be empty" + } + + check(!RustReservedWords.isReserved(name)) { + "Module `$name` cannot be a module name—it is a reserved word." + } + } } companion object { - fun default(name: String, visibility: Visibility, documentation: String? = null): RustModule { - return RustModule(name, RustMetadata(visibility = visibility), documentation) + + /** Creates a new module with the specified visibility */ + fun new( + name: String, + visibility: Visibility, + documentation: String? = null, + inline: Boolean = false, + parent: RustModule = LibRs, + ): LeafModule { + return LeafModule( + RustReservedWords.escapeIfNeeded(name), + RustMetadata(visibility = visibility), + documentation, + inline = inline, + parent = parent, + ) } - fun public(name: String, documentation: String? = null): RustModule = - default(name, visibility = Visibility.PUBLIC, documentation = documentation) + /** Creates a new public module */ + fun public(name: String, documentation: String? = null, parent: RustModule = LibRs): LeafModule = + new(name, visibility = Visibility.PUBLIC, documentation = documentation, inline = false, parent = parent) - fun private(name: String, documentation: String? = null): RustModule = - default(name, visibility = Visibility.PRIVATE, documentation = documentation) + /** Creates a new private module */ + fun private(name: String, documentation: String? = null, parent: RustModule = LibRs): LeafModule = + new(name, visibility = Visibility.PRIVATE, documentation = documentation, inline = false, parent = parent) /* Common modules used across client, server and tests */ val Config = public("config", documentation = "Configuration for the service.") @@ -36,6 +85,53 @@ data class RustModule(val name: String, val rustMetadata: RustMetadata, val docu * Its visibility depends on the generation context (client or server). */ fun operation(visibility: Visibility): RustModule = - default("operation", visibility = visibility, documentation = "All operations that this crate can perform.") + new( + "operation", + visibility = visibility, + documentation = "All operations that this crate can perform.", + ) + } + + fun isInline(): Boolean = when (this) { + is LibRs -> false + is LeafModule -> this.inline + } + + /** + * Fully qualified path to this module, e.g. `crate::grandparent::parent::child` + */ + fun fullyQualifiedPath(): String = when (this) { + is LibRs -> "crate" + is LeafModule -> parent.fullyQualifiedPath() + "::" + name + } + + /** + * The file this module is homed in, e.g. `src/grandparent/parent/child.rs` + */ + fun definitionFile(): String = when (this) { + is LibRs -> "src/lib.rs" + is LeafModule -> { + val path = fullyQualifiedPath().split("::").drop(1).joinToString("/") + "src/$path.rs" + } + } + + /** + * Renders the usage statement, approximately: + * ```rust + * /// My docs + * pub mod my_module_name + * ``` + */ + fun renderModStatement(writer: RustWriter) { + when (this) { + is LeafModule -> { + documentation?.let { docs -> writer.docs(docs) } + rustMetadata.render(writer) + writer.write("mod $name;") + } + + else -> {} + } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt index 33f15aa89e5..efe9ae7cc88 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt @@ -39,8 +39,10 @@ class RustReservedWordSymbolProvider(private val base: RustSymbolProvider, priva // To avoid conflicts with the `make_operation` and `presigned` functions on generated inputs "make_operation" -> "make_operation_value" "presigned" -> "presigned_value" + "customize" -> "customize_value" else -> baseName } + is UnionShape -> when (baseName) { // Unions contain an `Unknown` variant. This exists to support parsing data returned from the server // that represent union variants that have been added since this SDK was generated. @@ -53,6 +55,7 @@ class RustReservedWordSymbolProvider(private val base: RustSymbolProvider, priva "SelfValue" -> "SelfValue_" else -> baseName } + else -> error("unexpected container: $container") } } @@ -78,6 +81,7 @@ class RustReservedWordSymbolProvider(private val base: RustSymbolProvider, priva it.toBuilder().renamedFrom(previousName).build() } } + else -> base.toSymbol(shape) } } @@ -150,7 +154,6 @@ object RustReservedWords : ReservedWords { "abstract", "become", "box", - "customize", "do", "final", "macro", diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt index 8ef035f042b..d1b14e9d8c7 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt @@ -15,6 +15,8 @@ import software.amazon.smithy.codegen.core.SymbolWriter.Factory import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.DoubleShape +import software.amazon.smithy.model.shapes.FloatShape import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeId @@ -22,6 +24,7 @@ import software.amazon.smithy.model.traits.DeprecatedTrait import software.amazon.smithy.model.traits.DocumentationTrait import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.letIf @@ -429,7 +432,8 @@ class RustWriter private constructor( } /** - * Create an inline module. + * Create an inline module. Instead of being in a new file, inline modules are written as a `mod { ... }` block + * directly into the parent. * * Callers must take care to use [this] when writing to ensure code is written to the right place: * ```kotlin @@ -442,15 +446,19 @@ class RustWriter private constructor( * * The returned writer will inject any local imports into the module as needed. */ - fun withModule( - module: RustModule, + fun withInlineModule( + module: RustModule.LeafModule, moduleWriter: Writable, ): RustWriter { + check(module.isInline()) { + "Only inline modules may be used with `withInlineModule`: $module" + } // In Rust, modules must specify their own imports—they don't have access to the parent scope. // To easily handle this, create a new inner writer to collect imports, then dump it // into an inline module. val innerWriter = RustWriter(this.filename, "${this.namespace}::${module.name}", printWarning = false) moduleWriter(innerWriter) + module.documentation?.let { docs -> docs(docs) } module.rustMetadata.render(this) rustBlock("mod ${module.name}") { writeWithNoFormatting(innerWriter.toString()) @@ -460,33 +468,82 @@ class RustWriter private constructor( } /** - * Generate a wrapping if statement around a field. - * - * - If the field is optional, it will only be called if the field is present - * - If the field is an unboxed primitive, it will only be called if the field is non-zero - * + * Generate a wrapping if statement around a nullable value. + * The provided code block will only be called if the value is not `None`. */ - fun ifSet(shape: Shape, member: Symbol, outerField: String, block: RustWriter.(field: String) -> Unit) { + fun ifSome(member: Symbol, value: ValueExpression, block: RustWriter.(value: ValueExpression) -> Unit) { when { member.isOptional() -> { - val derefName = safeName("inner") - rustBlock("if let Some($derefName) = $outerField") { - block(derefName) + val innerValue = ValueExpression.Reference(safeName("inner")) + rustBlock("if let Some(${innerValue.name}) = ${value.asRef()}") { + block(innerValue) } } - shape is NumberShape -> rustBlock("if ${outerField.removePrefix("&")} != 0") { - block(outerField) + else -> this.block(value) + } + } + + /** + * Generate a wrapping if statement around a primitive field. + * The specified block will only be called if the field is not set to its default value - `0` for + * numbers, `false` for booleans. + */ + fun ifNotDefault(shape: Shape, variable: ValueExpression, block: RustWriter.(field: ValueExpression) -> Unit) { + when (shape) { + is FloatShape, is DoubleShape -> rustBlock("if ${variable.asValue()} != 0.0") { + block(variable) + } + + is NumberShape -> rustBlock("if ${variable.asValue()} != 0") { + block(variable) } - shape is BooleanShape -> rustBlock("if ${outerField.removePrefix("&")}") { - block(outerField) + is BooleanShape -> rustBlock("if ${variable.asValue()}") { + block(variable) } - else -> this.block(outerField) + else -> rustBlock("") { + this.block(variable) + } } } + /** + * Generate a wrapping if statement around a field. + * + * - If the field is optional, it will only be called if the field is present + * - If the field is an unboxed primitive, it will only be called if the field is non-zero + * + * # Example + * + * For a nullable structure shape (e.g. `Option`), the following code will be generated: + * + * ``` + * if let Some(v) = my_nullable_struct { + * /* {block(variable)} */ + * } + * ``` + * + * # Example + * + * For a non-nullable integer shape, the following code will be generated: + * + * ``` + * if my_int != 0 { + * /* {block(variable)} */ + * } + * ``` + */ + fun ifSet( + shape: Shape, + member: Symbol, + variable: ValueExpression, + block: RustWriter.(field: ValueExpression) -> Unit, + ) { + ifSome(member, variable) { inner -> ifNotDefault(shape, inner, block) } + } + fun listForEach( target: Shape, outerField: String, @@ -545,7 +602,8 @@ class RustWriter private constructor( inner class RustWriteableInjector : BiFunction { override fun apply(t: Any, u: String): String { @Suppress("UNCHECKED_CAST") - val func = t as? Writable ?: throw CodegenException("RustWriteableInjector.apply choked on non-function t ($t)") + val func = + t as? Writable ?: throw CodegenException("RustWriteableInjector.apply choked on non-function t ($t)") val innerWriter = RustWriter(filename, namespace, printWarning = false) func(innerWriter) innerWriter.dependencies.forEach { addDependency(it) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt index f1739324a7c..3a3434b8bc3 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt @@ -41,23 +41,24 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.ManifestCustom */ open class RustCrate( fileManifest: FileManifest, - symbolProvider: SymbolProvider, - /** - * For core modules like `input`, `output`, and `error`, we need to specify whether these modules should be public or - * private as well as any other metadata. [baseModules] enables configuring this. See [DefaultPublicModules]. - */ - baseModules: Map, + private val symbolProvider: SymbolProvider, coreCodegenConfig: CoreCodegenConfig, ) { private val inner = WriterDelegator(fileManifest, symbolProvider, RustWriter.factory(coreCodegenConfig.debugMode)) - private val modules: MutableMap = baseModules.toMutableMap() private val features: MutableSet = mutableSetOf() + // used to ensure we never create accidentally discard docs / incorrectly create modules with incorrect visibility + private var duplicateModuleWarningSystem: MutableMap = mutableMapOf() + /** * Write into the module that this shape is [locatedIn] */ fun useShapeWriter(shape: Shape, f: Writable) { - inner.useShapeWriter(shape, f) + val module = symbolProvider.toSymbol(shape).module() + check(!module.isInline()) { + "Cannot use useShapeWriter with inline modules—use [RustWriter.withInlineModule] instead" + } + withModule(symbolProvider.toSymbol(shape).module(), f) } /** @@ -94,14 +95,11 @@ open class RustCrate( requireDocs: Boolean = true, ) { injectInlineDependencies() - val modules = inner.writers.values.mapNotNull { it.module() }.filter { it != "lib" } - .mapNotNull { modules[it] } inner.finalize( settings, model, manifestCustomizations, libRsCustomizations, - modules, this.features.toList(), requireDocs, ) @@ -126,6 +124,17 @@ open class RustCrate( } } + private fun checkDups(module: RustModule.LeafModule) { + duplicateModuleWarningSystem[module.fullyQualifiedPath()]?.also { preexistingModule -> + check(module == preexistingModule) { + "Duplicate modules with differing properties were created! This will lead to non-deterministic behavior." + + "\n Previous module: $preexistingModule." + + "\n New module: $module" + } + } + duplicateModuleWarningSystem[module.fullyQualifiedPath()] = module + } + /** * Create a new module directly. The resulting module will be placed in `src/.rs` */ @@ -133,31 +142,22 @@ open class RustCrate( module: RustModule, moduleWriter: Writable, ): RustCrate { - val moduleName = module.name - modules[moduleName] = module - inner.useFileWriter("src/$moduleName.rs", "crate::$moduleName", moduleWriter) - return this - } - - /** - * Create a new non-root module directly. For example, if given the namespace `crate::foo::bar`, - * this will create `src/foo/bar.rs` with the contents from the given `moduleWriter`. - * Multiple calls to this with the same namespace are additive, so new code can be added - * by various customizations. - * - * Caution: this does not automatically add the required Rust `mod` statements to make this - * file an official part of the generated crate. This step needs to be done manually. - */ - fun withNonRootModule( - namespace: String, - moduleWriter: Writable, - ): RustCrate { - val parts = namespace.split("::") - require(parts.size > 2) { "Cannot create root modules using withNonRootModule" } - require(parts[0] == "crate") { "Namespace must start with crate::" } - - val fileName = "src/" + parts.filterIndexed { index, _ -> index > 0 }.joinToString("/") + ".rs" - inner.useFileWriter(fileName, namespace, moduleWriter) + when (module) { + is RustModule.LibRs -> lib { moduleWriter(this) } + is RustModule.LeafModule -> { + checkDups(module) + // Create a dependency which adds the mod statement for this module. This will be added to the writer + // so that _usage_ of this module will generate _exactly one_ `mod ` with the correct modifiers. + val modStatement = RuntimeType.forInlineFun("mod_" + module.fullyQualifiedPath(), module.parent) { + module.renderModStatement(this) + } + val path = module.fullyQualifiedPath().split("::").drop(1).joinToString("/") + inner.useFileWriter("src/$path.rs", module.fullyQualifiedPath()) { writer -> + moduleWriter(writer) + writer.addDependency(modStatement.dependency) + } + } + } return this } @@ -176,19 +176,11 @@ val OperationsModule = RustModule.public("operation", documentation = "All opera val ModelsModule = RustModule.public("model", documentation = "Data structures used by operation inputs/outputs.") val InputsModule = RustModule.public("input", documentation = "Input structures for operations.") val OutputsModule = RustModule.public("output", documentation = "Output structures for operations.") -val ConfigModule = RustModule.public("config", documentation = "Client configuration.") -/** - * Allowlist of modules that will be exposed publicly in generated crates - */ -val DefaultPublicModules = setOf( - ErrorsModule, - OperationsModule, - ModelsModule, - InputsModule, - OutputsModule, - ConfigModule, -).associateBy { it.name } +val UnconstrainedModule = + RustModule.private("unconstrained", "Unconstrained types for constrained shapes.") +val ConstrainedModule = + RustModule.private("constrained", "Constrained types for constrained shapes.") /** * Finalize all the writers by: @@ -200,12 +192,11 @@ fun WriterDelegator.finalize( model: Model, manifestCustomizations: ManifestCustomizations, libRsCustomizations: List, - modules: List, features: List, requireDocs: Boolean, ) { this.useFileWriter("src/lib.rs", "crate::lib") { - LibRsGenerator(settings, model, modules, libRsCustomizations, requireDocs).render(it) + LibRsGenerator(settings, model, libRsCustomizations, requireDocs).render(it) } val cargoDependencies = mergeDependencyFeatures( this.dependencies.map { RustDependency.fromSymbolDependency(it) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt index baabdc7b077..ebaa19856ca 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt @@ -16,12 +16,14 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CratesIo import software.amazon.smithy.rust.codegen.core.rustlang.DependencyLocation import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency +import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency.Companion.constrained import software.amazon.smithy.rust.codegen.core.rustlang.Local import software.amazon.smithy.rust.codegen.core.rustlang.RustDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.asType import software.amazon.smithy.rust.codegen.core.rustlang.rustInlineTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.util.orNull @@ -256,14 +258,14 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n ) } + fun ConstrainedTrait() = constrained().asType().member("Constrained") + fun MaybeConstrained() = constrained().asType().member("MaybeConstrained") + fun ProtocolTestHelper(runtimeConfig: RuntimeConfig, func: String): RuntimeType = RuntimeType( func, CargoDependency.smithyProtocolTestHelpers(runtimeConfig), "aws_smithy_protocol_test", ) - fun ConstrainedTrait() = RuntimeType("Constrained", InlineDependency.constrained(), namespace = "crate::constrained") - fun MaybeConstrained() = RuntimeType("MaybeConstrained", InlineDependency.constrained(), namespace = "crate::constrained") - val http = CargoDependency.Http.toType() fun Http(path: String): RuntimeType = RuntimeType(name = path, dependency = CargoDependency.Http, namespace = "http") @@ -313,7 +315,7 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n fun forInlineFun(name: String, module: RustModule, func: Writable) = RuntimeType( name = name, dependency = InlineDependency(name, module, listOf(), func), - namespace = "crate::${module.name}", + namespace = module.fullyQualifiedPath(), ) fun parseResponse(runtimeConfig: RuntimeConfig) = RuntimeType( diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt index 1ca4c2ebf25..d2e66a87aa9 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt @@ -38,6 +38,7 @@ import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumDefinition import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait @@ -69,24 +70,6 @@ data class SymbolVisitorConfig( val nullabilityCheckMode: CheckMode, ) -/** - * Container type for the file a symbol should be written to - * - * Downstream code uses symbol location to determine which file to use acquiring a writer - */ -data class SymbolLocation(val namespace: String) { - val filename = "$namespace.rs" -} - -val Models = SymbolLocation(ModelsModule.name) -val Errors = SymbolLocation(ErrorsModule.name) -val Operations = SymbolLocation(OperationsModule.name) -val Serializers = SymbolLocation("serializer") -val Inputs = SymbolLocation(InputsModule.name) -val Outputs = SymbolLocation(OutputsModule.name) -val Unconstrained = SymbolLocation("unconstrained") -val Constrained = SymbolLocation("constrained") - /** * Make the Rust type of a symbol optional (hold `Option`) * @@ -152,15 +135,16 @@ fun Symbol.mapRustType(f: (RustType) -> RustType): Symbol { } /** Set the symbolLocation for this symbol builder */ -fun Symbol.Builder.locatedIn(symbolLocation: SymbolLocation): Symbol.Builder { +fun Symbol.Builder.locatedIn(rustModule: RustModule.LeafModule): Symbol.Builder { val currentRustType = this.build().rustType() check(currentRustType is RustType.Opaque) { "Only `Opaque` can have their namespace updated" } - val newRustType = currentRustType.copy(namespace = "crate::${symbolLocation.namespace}") - return this.definitionFile("src/${symbolLocation.filename}") - .namespace("crate::${symbolLocation.namespace}", "::") + val newRustType = currentRustType.copy(namespace = rustModule.fullyQualifiedPath()) + return this.definitionFile(rustModule.definitionFile()) + .namespace(rustModule.fullyQualifiedPath(), "::") .rustType(newRustType) + .module(rustModule) } /** @@ -274,7 +258,7 @@ open class SymbolVisitor( override fun stringShape(shape: StringShape): Symbol { return if (shape.hasTrait()) { val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase()) - symbolBuilder(shape, rustType).locatedIn(Models).build() + symbolBuilder(shape, rustType).locatedIn(ModelsModule).build() } else { simpleShape(shape) } @@ -325,7 +309,7 @@ open class SymbolVisitor( .replaceFirstChar { it.uppercase() }, ), ) - .locatedIn(Operations) + .locatedIn(OperationsModule) .build() } @@ -346,16 +330,16 @@ open class SymbolVisitor( } val builder = symbolBuilder(shape, RustType.Opaque(name)) return when { - isError -> builder.locatedIn(Errors) - isInput -> builder.locatedIn(Inputs) - isOutput -> builder.locatedIn(Outputs) - else -> builder.locatedIn(Models) + isError -> builder.locatedIn(ErrorsModule) + isInput -> builder.locatedIn(InputsModule) + isOutput -> builder.locatedIn(OutputsModule) + else -> builder.locatedIn(ModelsModule) }.build() } override fun unionShape(shape: UnionShape): Symbol { val name = shape.contextName(serviceShape).toPascalCase() - val builder = symbolBuilder(shape, RustType.Opaque(name)).locatedIn(Models) + val builder = symbolBuilder(shape, RustType.Opaque(name)).locatedIn(ModelsModule) return builder.build() } @@ -399,13 +383,15 @@ fun symbolBuilder(shape: Shape?, rustType: RustType): Symbol.Builder { fun handleOptionality(symbol: Symbol, member: MemberShape, nullableIndex: NullableIndex, nullabilityCheckMode: CheckMode): Symbol = symbol.letIf(nullableIndex.isMemberNullable(member, nullabilityCheckMode)) { symbol.makeOptional() } -// TODO(chore): Move this to a useful place private const val RUST_TYPE_KEY = "rusttype" +private const val RUST_MODULE_KEY = "rustmodule" private const val SHAPE_KEY = "shape" private const val SYMBOL_DEFAULT = "symboldefault" private const val RENAMED_FROM_KEY = "renamedfrom" fun Symbol.Builder.rustType(rustType: RustType): Symbol.Builder = this.putProperty(RUST_TYPE_KEY, rustType) +fun Symbol.Builder.module(module: RustModule.LeafModule): Symbol.Builder = this.putProperty(RUST_MODULE_KEY, module) +fun Symbol.module(): RustModule.LeafModule = this.expectProperty(RUST_MODULE_KEY, RustModule.LeafModule::class.java) fun Symbol.Builder.renamedFrom(name: String): Symbol.Builder { return this.putProperty(RENAMED_FROM_KEY, name) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt index 89bae1ea493..df39a50841a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt @@ -14,6 +14,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.asArgument import software.amazon.smithy.rust.codegen.core.rustlang.asOptional @@ -36,7 +37,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.canUseDefault import software.amazon.smithy.rust.codegen.core.smithy.defaultValue import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.locatedIn import software.amazon.smithy.rust.codegen.core.smithy.makeOptional +import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.util.dq @@ -53,12 +56,12 @@ fun builderSymbolFn(symbolProvider: RustSymbolProvider): (StructureShape) -> Sym fun StructureShape.builderSymbol(symbolProvider: RustSymbolProvider): Symbol { val structureSymbol = symbolProvider.toSymbol(this) val builderNamespace = RustReservedWords.escapeIfNeeded(structureSymbol.name.toSnakeCase()) - val rustType = RustType.Opaque("Builder", "${structureSymbol.namespace}::$builderNamespace") + val module = RustModule.new(builderNamespace, visibility = Visibility.PUBLIC, parent = structureSymbol.module(), inline = true) + val rustType = RustType.Opaque("Builder", module.fullyQualifiedPath()) return Symbol.builder() .rustType(rustType) .name(rustType.name) - .namespace(rustType.namespace, "::") - .definitionFile(structureSymbol.definitionFile) + .locatedIn(module) .build() } @@ -112,7 +115,7 @@ class BuilderGenerator( val symbol = symbolProvider.toSymbol(shape) writer.docs("See #D.", symbol) val segments = shape.builderSymbol(symbolProvider).namespace.split("::") - writer.withModule(RustModule.public(segments.last())) { + writer.withInlineModule(shape.builderSymbol(symbolProvider).module()) { renderBuilder(this) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt index 0acfe2641d4..95a46649eaa 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt @@ -141,7 +141,7 @@ open class EnumGenerator( } docs("Returns all the `&str` representations of the enum members.") - rustBlock("pub fn $Values() -> &'static [&'static str]") { + rustBlock("pub const fn $Values() -> &'static [&'static str]") { withBlock("&[", "]") { val memberList = sortedMembers.joinToString(", ") { it.value.dq() } rust(memberList) @@ -198,7 +198,7 @@ open class EnumGenerator( } rust("/// Returns all the `&str` values of the enum members.") - rustBlock("pub fn $Values() -> &'static [&'static str]") { + rustBlock("pub const fn $Values() -> &'static [&'static str]") { withBlock("&[", "]") { val memberList = sortedMembers.joinToString(", ") { it.value.doubleQuote() } write(memberList) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/LibRsGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/LibRsGenerator.kt index 3be71980804..c15d545e6b6 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/LibRsGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/LibRsGenerator.kt @@ -7,7 +7,6 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators import software.amazon.smithy.model.Model import software.amazon.smithy.model.traits.DocumentationTrait -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.containerDocs import software.amazon.smithy.rust.codegen.core.rustlang.escape @@ -33,7 +32,6 @@ typealias LibRsCustomization = NamedSectionGenerator class LibRsGenerator( private val settings: CoreRustSettings, private val model: Model, - private val modules: List, private val customizations: List, private val requireDocs: Boolean, ) { @@ -66,7 +64,6 @@ class LibRsGenerator( // TODO(docs): Automated feature documentation } - modules.forEach { it.render(writer) } customizations.forEach { it.section(LibRsSection.Body(model))(writer) } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt index 4303fa1a688..071a5bd89a3 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt @@ -25,6 +25,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.Std import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.mapRustType +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.REDACTION import software.amazon.smithy.rust.codegen.core.util.dq @@ -144,8 +145,8 @@ class ErrorGenerator( if (it.shouldRedact(model)) { write("""write!(f, ": {}", $REDACTION)?;""") } else { - ifSet(it, symbolProvider.toSymbol(it), "&self.message") { field -> - write("""write!(f, ": {}", $field)?;""") + ifSet(it, symbolProvider.toSymbol(it), ValueExpression.Reference("&self.message")) { field -> + write("""write!(f, ": {}", ${field.asRef()})?;""") } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt index 219784eff34..8ff3880c174 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators.http import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.shapes.BlobShape @@ -29,8 +30,8 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.asOptional -import software.amazon.smithy.rust.codegen.core.rustlang.autoDeref import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock @@ -45,12 +46,14 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedSectionGen import software.amazon.smithy.rust.codegen.core.smithy.customize.Section import software.amazon.smithy.rust.codegen.core.smithy.generators.OperationBuildError import software.amazon.smithy.rust.codegen.core.smithy.generators.operationBuildError +import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.makeOptional import software.amazon.smithy.rust.codegen.core.smithy.mapRustType import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE import software.amazon.smithy.rust.codegen.core.util.dq @@ -80,6 +83,10 @@ enum class HttpMessageType { sealed class HttpBindingSection(name: String) : Section(name) { data class BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders(val variableName: String, val shape: MapShape) : HttpBindingSection("BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders") + + data class BeforeRenderingHeaderValue(var context: HttpBindingGenerator.HeaderValueSerializationContext) : + HttpBindingSection("BeforeRenderingHeaderValue") + data class AfterDeserializingIntoAHashMapOfHttpPrefixHeaders(val memberShape: MemberShape) : HttpBindingSection("AfterDeserializingIntoAHashMapOfHttpPrefixHeaders") } @@ -299,6 +306,7 @@ class HttpBindingGenerator( "error_symbol" to errorSymbol, ) } + HttpMessageType.REQUEST -> { rust("let body_str = std::str::from_utf8(body)?;") } @@ -319,6 +327,7 @@ class HttpBindingGenerator( rust("Ok(body_str.to_string())") } } + is BlobShape -> rust( "Ok(#T::new(body))", symbolProvider.toSymbol(targetShape), @@ -401,6 +410,7 @@ class HttpBindingGenerator( }) """, ) + is RustType.HashSet -> rust( """ @@ -411,6 +421,7 @@ class HttpBindingGenerator( }) """, ) + else -> { if (targetShape is ListShape) { // This is a constrained list shape and we must therefore be generating a server SDK. @@ -449,7 +460,9 @@ class HttpBindingGenerator( */ // Rename here technically not required, operations and members cannot be renamed. private fun fnName(operationShape: OperationShape, binding: HttpBindingDescriptor) = - "${operationShape.id.getName(service).toSnakeCase()}_${binding.member.container.name.toSnakeCase()}_${binding.memberName.toSnakeCase()}" + "${ + operationShape.id.getName(service).toSnakeCase() + }_${binding.member.container.name.toSnakeCase()}_${binding.memberName.toSnakeCase()}" /** * Returns a function to set headers on an HTTP message for the given [shape]. @@ -467,6 +480,7 @@ class HttpBindingGenerator( // Only a single structure member can be bound by `httpPrefixHeaders`, hence the `getOrNull(0)`. HttpMessageType.REQUEST -> index.getRequestBindings(shape, HttpLocation.HEADER) to index.getRequestBindings(shape, HttpLocation.PREFIX_HEADERS).getOrNull(0) + HttpMessageType.RESPONSE -> index.getResponseBindings(shape, HttpLocation.HEADER) to index.getResponseBindings(shape, HttpLocation.PREFIX_HEADERS).getOrNull(0) } @@ -517,50 +531,135 @@ class HttpBindingGenerator( check(httpBinding.location == HttpLocation.HEADER) val memberShape = httpBinding.member val targetShape = model.expectShape(memberShape.target) - val memberSymbol = symbolProvider.toSymbol(memberShape) val memberName = symbolProvider.toMemberName(memberShape) - ifSet(targetShape, memberSymbol, "&input.$memberName") { field -> - listForEach(targetShape, field) { innerField, targetId -> - val innerMemberType = model.expectShape(targetId) - if (innerMemberType.isPrimitive()) { - val encoder = CargoDependency.smithyTypes(runtimeConfig).toType().member("primitive::Encoder") - rust("let mut encoder = #T::from(${autoDeref(innerField)});", encoder) - } - val formatted = headerFmtFun( - this, - innerMemberType, - memberShape, - innerField, - isListHeader = targetShape is CollectionShape, - ) - val safeName = safeName("formatted") - write("let $safeName = $formatted;") - rustBlock("if !$safeName.is_empty()") { - rustTemplate( - """ - let header_value = $safeName; - let header_value = http::header::HeaderValue::try_from(&*header_value).map_err(|err| { - #{invalid_field_error:W} - })?; - builder = builder.header("${httpBinding.locationName}", header_value); - """, - "invalid_field_error" to OperationBuildError(runtimeConfig).invalidField(memberName) { - rust( - """ - format!( - "`{}` cannot be used as a header value: {}", - &${memberShape.redactIfNecessary(model, "header_value")}, - err - ) - """, - ) - }, + val headerName = httpBinding.locationName + val timestampFormat = + index.determineTimestampFormat(memberShape, HttpBinding.Location.HEADER, defaultTimestampFormat) + val renderErrorMessage = { headerValueVariableName: String -> + OperationBuildError(runtimeConfig).invalidField(memberName) { + rust( + """ + format!( + "`{}` cannot be used as a header value: {}", + &${memberShape.redactIfNecessary(model, headerValueVariableName)}, + err ) - } + """, + ) + } + } + + val memberSymbol = symbolProvider.toSymbol(memberShape) + // If a header is of a primitive type and required (e.g. `bool`), we do not serialize it on the + // wire if it's set to the default value for that primitive type (e.g. `false` for `bool`). + // If the header is optional, instead, we want to serialize it if it has been set by the user to the + // default value for that primitive type (e.g. `Some(false)` for an `Option` header). + // If a header is multivalued, we always want to serialize its primitive members, regardless of their + // values. + val serializePrimitiveValuesIfDefault = memberSymbol.isOptional() || (targetShape is CollectionShape) + ifSome(memberSymbol, ValueExpression.Reference("&input.$memberName")) { variableName -> + if (targetShape is CollectionShape) { + renderMultiValuedHeader( + model, + headerName, + variableName, + targetShape, + timestampFormat, + renderErrorMessage, + ) + } else { + renderHeaderValue( + headerName, + variableName, + targetShape, + false, + timestampFormat, + renderErrorMessage, + serializePrimitiveValuesIfDefault, + ) } } } + private fun RustWriter.renderMultiValuedHeader( + model: Model, + headerName: String, + value: ValueExpression, + shape: CollectionShape, + timestampFormat: TimestampFormatTrait.Format, + renderErrorMessage: (String) -> Writable, + ) { + val loopVariable = ValueExpression.Reference(safeName("inner")) + rustBlock("for ${loopVariable.name} in ${value.asRef()}") { + this.renderHeaderValue( + headerName, + loopVariable, + model.expectShape(shape.member.target), + isMultiValuedHeader = true, + timestampFormat, + renderErrorMessage, + serializeIfDefault = true, + ) + } + } + + data class HeaderValueSerializationContext( + /** Expression representing the value to write to the JsonValueWriter */ + var valueExpression: ValueExpression, + /** Path in the JSON to get here, used for errors */ + val shape: Shape, + ) + + private fun RustWriter.renderHeaderValue( + headerName: String, + value: ValueExpression, + shape: Shape, + isMultiValuedHeader: Boolean, + timestampFormat: TimestampFormatTrait.Format, + renderErrorMessage: (String) -> Writable, + serializeIfDefault: Boolean, + ) { + val context = HeaderValueSerializationContext(value, shape) + for (customization in customizations) { + customization.section( + HttpBindingSection.BeforeRenderingHeaderValue(context), + )(this) + } + + val block: RustWriter.(value: ValueExpression) -> Unit = { variableName -> + if (shape.isPrimitive()) { + val encoder = CargoDependency.smithyTypes(runtimeConfig).toType().member("primitive::Encoder") + rust("let mut encoder = #T::from(${variableName.asValue()});", encoder) + } + val formatted = headerFmtFun( + this, + shape, + timestampFormat, + context.valueExpression.name, + isMultiValuedHeader = isMultiValuedHeader, + ) + val safeName = safeName("formatted") + rustTemplate( + """ + let $safeName = $formatted; + if !$safeName.is_empty() { + let header_value = $safeName; + let header_value = http::header::HeaderValue::try_from(&*header_value).map_err(|err| { + #{invalid_field_error:W} + })?; + builder = builder.header("$headerName", header_value); + } + """, + "invalid_field_error" to renderErrorMessage("header_value"), + ) + } + if (serializeIfDefault) { + block(context.valueExpression) + } else { + ifNotDefault(context.shape, context.valueExpression, block) + } + } + private fun RustWriter.renderPrefixHeader(httpBinding: HttpBinding) { check(httpBinding.location == HttpLocation.PREFIX_HEADERS) val memberShape = httpBinding.member @@ -568,21 +667,31 @@ class HttpBindingGenerator( val memberSymbol = symbolProvider.toSymbol(memberShape) val memberName = symbolProvider.toMemberName(memberShape) val valueTargetShape = model.expectShape(targetShape.value.target) + val timestampFormat = + index.determineTimestampFormat(memberShape, HttpBinding.Location.HEADER, defaultTimestampFormat) - ifSet(targetShape, memberSymbol, "&input.$memberName") { field -> + ifSet(targetShape, memberSymbol, ValueExpression.Reference("&input.$memberName")) { local -> for (customization in customizations) { customization.section( - HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders(field, targetShape), + HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders(local.name, targetShape), )(this) } rustTemplate( """ - for (k, v) in $field { + for (k, v) in ${local.asRef()} { use std::str::FromStr; let header_name = http::header::HeaderName::from_str(&format!("{}{}", "${httpBinding.locationName}", &k)).map_err(|err| { #{invalid_header_name:W} })?; - let header_value = ${headerFmtFun(this, valueTargetShape, memberShape, "v", isListHeader = false)}; + let header_value = ${ + headerFmtFun( + this, + valueTargetShape, + timestampFormat, + "v", + isMultiValuedHeader = false, + ) + }; let header_value = http::header::HeaderValue::try_from(&*header_value).map_err(|err| { #{invalid_header_value:W} })?; @@ -611,10 +720,16 @@ class HttpBindingGenerator( /** * Format [member] when used as an HTTP header. */ - private fun headerFmtFun(writer: RustWriter, target: Shape, member: MemberShape, targetName: String, isListHeader: Boolean): String { + private fun headerFmtFun( + writer: RustWriter, + target: Shape, + timestampFormat: TimestampFormatTrait.Format, + targetName: String, + isMultiValuedHeader: Boolean, + ): String { fun quoteValue(value: String): String { // Timestamp shapes are not quoted in header lists - return if (isListHeader && !target.isTimestampShape) { + return if (isMultiValuedHeader && !target.isTimestampShape) { val quoteFn = writer.format(headerUtil.member("quote_header_value")) "$quoteFn($value)" } else { @@ -630,18 +745,20 @@ class HttpBindingGenerator( quoteValue("$targetName.as_str()") } } + target.isTimestampShape -> { - val timestampFormat = - index.determineTimestampFormat(member, HttpBinding.Location.HEADER, defaultTimestampFormat) val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat) quoteValue("$targetName.fmt(${writer.format(timestampFormatType)})?") } + target.isListShape || target.isMemberShape -> { throw IllegalArgumentException("lists should be handled at a higher level") } + target.isPrimitive() -> { "encoder.encode()" } + else -> throw CodegenException("unexpected shape: $target") } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/RequestBindingGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/RequestBindingGenerator.kt index 0a1421c9385..e5d0a74c8ee 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/RequestBindingGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/RequestBindingGenerator.kt @@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectMember import software.amazon.smithy.rust.codegen.core.util.hasTrait @@ -193,8 +194,12 @@ class RequestBindingGenerator( val memberName = symbolProvider.toMemberName(memberShape) val targetShape = model.expectShape(memberShape.target, MapShape::class.java) val stringFormatter = RuntimeType.QueryFormat(runtimeConfig, "fmt_string") - ifSet(model.expectShape(param.member.target), memberSymbol, "&_input.$memberName") { field -> - rustBlock("for (k, v) in $field") { + ifSet( + model.expectShape(param.member.target), + memberSymbol, + ValueExpression.Reference("&_input.$memberName"), + ) { value -> + rustBlock("for (k, v) in ${value.asRef()}") { // if v is a list, generate another level of iteration listForEach(model.expectShape(targetShape.value.target), "v") { innerField, _ -> rustBlock("if !protected_params.contains(&k.as_str())") { @@ -236,9 +241,9 @@ class RequestBindingGenerator( paramList(target, derefName, param, writer, memberShape) } else { - ifSet(target, memberSymbol, "&_input.$memberName") { field -> + ifSet(target, memberSymbol, ValueExpression.Reference("&_input.$memberName")) { field -> // if `param` is a list, generate another level of iteration - paramList(target, field, param, writer, memberShape) + paramList(target, field.name, param, writer, memberShape) } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index 1d7a099272d..4693cc31f78 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -334,7 +334,7 @@ class JsonParserGenerator( .map(#{NumberType}::try_from) .transpose()? """, - "NumberType" to symbolProvider.toSymbol(target), + "NumberType" to returnSymbolToParse(target).symbol, *codegenScope, ) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt index 82be9e78188..c7805fc4681 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -7,14 +7,20 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.BooleanShape +import software.amazon.smithy.model.shapes.ByteShape import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.DocumentShape +import software.amazon.smithy.model.shapes.DoubleShape +import software.amazon.smithy.model.shapes.FloatShape +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.model.shapes.LongShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.ShortShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape @@ -22,7 +28,6 @@ import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.TimestampFormatTrait.Format.EPOCH_SECONDS import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustModule -import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock @@ -42,7 +47,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.serializeFunctionName -import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectTrait @@ -54,16 +58,24 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape */ sealed class JsonSerializerSection(name: String) : Section(name) { /** Mutate the server error object prior to finalization. Eg: this can be used to inject `__type` to record the error type. */ - data class ServerError(val structureShape: StructureShape, val jsonObject: String) : JsonSerializerSection("ServerError") + data class ServerError(val structureShape: StructureShape, val jsonObject: String) : + JsonSerializerSection("ServerError") - /** Mutate a map prior to it being serialized. **/ - data class BeforeIteratingOverMap(val shape: MapShape, val valueExpression: ValueExpression) : JsonSerializerSection("BeforeIteratingOverMap") + /** Manipulate the serializer context for a map prior to it being serialized. **/ + data class BeforeIteratingOverMap(val shape: MapShape, val context: JsonSerializerGenerator.Context) : + JsonSerializerSection("BeforeIteratingOverMap") + + /** Manipulate the serializer context for a non-null member prior to it being serialized. **/ + data class BeforeSerializingNonNullMember(val shape: Shape, val context: JsonSerializerGenerator.MemberContext) : + JsonSerializerSection("BeforeSerializingNonNullMember") /** Mutate the input object prior to finalization. */ - data class InputStruct(val structureShape: StructureShape, val jsonObject: String) : JsonSerializerSection("InputStruct") + data class InputStruct(val structureShape: StructureShape, val jsonObject: String) : + JsonSerializerSection("InputStruct") /** Mutate the output object prior to finalization. */ - data class OutputStruct(val structureShape: StructureShape, val jsonObject: String) : JsonSerializerSection("OutputStruct") + data class OutputStruct(val structureShape: StructureShape, val jsonObject: String) : + JsonSerializerSection("OutputStruct") } /** @@ -78,20 +90,20 @@ class JsonSerializerGenerator( private val jsonName: (MemberShape) -> String, private val customizations: List = listOf(), ) : StructuredDataSerializerGenerator { - private data class Context( + data class Context( /** Expression that retrieves a JsonValueWriter from either a JsonObjectWriter or JsonArrayWriter */ val writerExpression: String, /** Expression representing the value to write to the JsonValueWriter */ - val valueExpression: ValueExpression, + var valueExpression: ValueExpression, /** Path in the JSON to get here, used for errors */ val shape: T, ) - private data class MemberContext( + data class MemberContext( /** Expression that retrieves a JsonValueWriter from either a JsonObjectWriter or JsonArrayWriter */ val writerExpression: String, /** Expression representing the value to write to the JsonValueWriter */ - val valueExpression: ValueExpression, + var valueExpression: ValueExpression, val shape: MemberShape, /** Whether to serialize null values if the type is optional */ val writeNulls: Boolean = false, @@ -144,7 +156,7 @@ class JsonSerializerGenerator( } // Specialized since it holds a JsonObjectWriter expression rather than a JsonValueWriter - private data class StructContext( + data class StructContext( /** Name of the JsonObjectWriter */ val objectName: String, /** Name of the variable that holds the struct */ @@ -337,8 +349,16 @@ class JsonSerializerGenerator( if (symbolProvider.toSymbol(context.shape).isOptional()) { safeName().also { local -> rustBlock("if let Some($local) = ${context.valueExpression.asRef()}") { - val innerContext = context.copy(valueExpression = ValueExpression.Reference(local)) - serializeMemberValue(innerContext, targetShape) + context.valueExpression = ValueExpression.Reference(local) + for (customization in customizations) { + customization.section( + JsonSerializerSection.BeforeSerializingNonNullMember( + targetShape, + context, + ), + )(this) + } + serializeMemberValue(context, targetShape) } if (context.writeNulls) { rustBlock("else") { @@ -347,6 +367,12 @@ class JsonSerializerGenerator( } } } else { + for (customization in customizations) { + customization.section(JsonSerializerSection.BeforeSerializingNonNullMember(targetShape, context))( + this, + ) + } + with(serializerUtil) { ignoreZeroValues(context.shape, context.valueExpression) { serializeMemberValue(context, targetShape) @@ -363,10 +389,9 @@ class JsonSerializerGenerator( is StringShape -> rust("$writer.string(${value.name}.as_str());") is BooleanShape -> rust("$writer.boolean(${value.asValue()});") is NumberShape -> { - val numberType = when (symbolProvider.toSymbol(target).rustType()) { - is RustType.Float -> "Float" - // NegInt takes an i64 while PosInt takes u64. We need this to be signed here - is RustType.Integer -> "NegInt" + val numberType = when (target) { + is IntegerShape, is ByteShape, is LongShape, is ShortShape -> "NegInt" + is DoubleShape, is FloatShape -> "Float" else -> throw IllegalStateException("unreachable") } rust( @@ -374,10 +399,12 @@ class JsonSerializerGenerator( smithyTypes.member("Number"), ) } + is BlobShape -> rust( "$writer.string_unchecked(&#T(${value.asRef()}));", RuntimeType.Base64Encode(runtimeConfig), ) + is TimestampShape -> { val timestampFormat = httpBindingResolver.timestampFormat(context.shape, HttpLocation.DOCUMENT, EPOCH_SECONDS) @@ -388,18 +415,23 @@ class JsonSerializerGenerator( "ConvertInto" to typeConversionGenerator.convertViaInto(target), ) } + is CollectionShape -> jsonArrayWriter(context) { arrayName -> serializeCollection(Context(arrayName, value, target)) } + is MapShape -> jsonObjectWriter(context) { objectName -> serializeMap(Context(objectName, value, target)) } + is StructureShape -> jsonObjectWriter(context) { objectName -> serializeStructure(StructContext(objectName, value.asRef(), target)) } + is UnionShape -> jsonObjectWriter(context) { objectName -> serializeUnion(Context(objectName, value, target)) } + is DocumentShape -> rust("$writer.document(${value.asRef()});") else -> TODO(target.toString()) } @@ -432,7 +464,9 @@ class JsonSerializerGenerator( val keyName = safeName("key") val valueName = safeName("value") for (customization in customizations) { - customization.section(JsonSerializerSection.BeforeIteratingOverMap(context.shape, context.valueExpression))(this) + customization.section(JsonSerializerSection.BeforeIteratingOverMap(context.shape, context))( + this, + ) } rustBlock("for ($keyName, $valueName) in ${context.valueExpression.asRef()}") { val keyExpression = "$keyName.as_str()" diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/SerializerUtil.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/SerializerUtil.kt index 98bdc80aa0b..78012b1d6d4 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/SerializerUtil.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/SerializerUtil.kt @@ -6,11 +6,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize import software.amazon.smithy.model.Model -import software.amazon.smithy.model.shapes.BooleanShape -import software.amazon.smithy.model.shapes.DoubleShape -import software.amazon.smithy.model.shapes.FloatShape import software.amazon.smithy.model.shapes.MemberShape -import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable @@ -18,16 +14,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock class SerializerUtil(private val model: Model) { fun RustWriter.ignoreZeroValues(shape: MemberShape, value: ValueExpression, inner: Writable) { - val expr = when (model.expectShape(shape.target)) { - is FloatShape, is DoubleShape -> "${value.asValue()} != 0.0" - is NumberShape -> "${value.asValue()} != 0" - is BooleanShape -> value.asValue() - else -> null - } - - if (expr == null || - // Required shapes should always be serialized - // See https://github.com/awslabs/smithy-rs/issues/230 and https://github.com/aws/aws-sdk-go-v2/pull/1129 + // Required shapes should always be serialized + // See https://github.com/awslabs/smithy-rs/issues/230 and https://github.com/aws/aws-sdk-go-v2/pull/1129 + if ( shape.isRequired || // Zero values are always serialized in lists and collections, this only applies to structures model.expectShape(shape.container) !is StructureShape @@ -36,9 +25,7 @@ class SerializerUtil(private val model: Model) { inner(this) } } else { - rustBlock("if $expr") { - inner(this) - } + this.ifNotDefault(model.expectShape(shape.target), value) { inner(this) } } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/ValueExpression.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/ValueExpression.kt index 6ab3aea1e34..00bc8ba74c7 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/ValueExpression.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/ValueExpression.kt @@ -22,4 +22,6 @@ sealed class ValueExpression { is Reference -> name is Value -> "&$name" } + + override fun toString(): String = this.name } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt index c5bc50ed42e..eb250a05f71 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt @@ -18,7 +18,6 @@ import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.traits.EnumDefinition import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustDependency -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.raw @@ -85,6 +84,9 @@ object TestWorkspace { version = "0.0.1" """.trimIndent(), ) + // ensure there at least an empty lib.rs file to avoid broken crates + newProject.resolve("src").mkdirs() + newProject.resolve("src/lib.rs").writeText("") subprojects.add(newProject.name) generate() return newProject @@ -190,14 +192,6 @@ fun RustWriter.unitTest( return rustBlock("fn $name()", *args, block = block) } -val DefaultTestPublicModules = setOf( - RustModule.Error, - RustModule.Model, - RustModule.Input, - RustModule.Output, - RustModule.Config, -).associateBy { it.name } - /** * WriterDelegator used for test purposes * @@ -211,7 +205,6 @@ class TestWriterDelegator( RustCrate( fileManifest, symbolProvider, - DefaultTestPublicModules, codegenConfig, ) { val baseDir: Path = fileManifest.baseDir @@ -219,6 +212,8 @@ class TestWriterDelegator( fun printGeneratedFiles() { fileManifest.printGeneratedFiles() } + + fun generatedFiles() = fileManifest.files.map { baseDir.relativize(it) } } fun FileManifest.printGeneratedFiles() { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt index 5bc0744ba12..3fff4b78044 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt @@ -67,7 +67,7 @@ fun testRustSettings( private const val SmithyVersion = "1.0" fun String.asSmithyModel(sourceLocation: String? = null, smithyVersion: String = SmithyVersion): Model { - val processed = letIf(!this.startsWith("\$version")) { "\$version: ${smithyVersion.dq()}\n$it" } + val processed = letIf(!this.trimStart().startsWith("\$version")) { "\$version: ${smithyVersion.dq()}\n$it" } return Model.assembler().discoverModels().addUnparsedModel(sourceLocation ?: "test.smithy", processed).assemble() .unwrap() } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/InlineDependencyTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/InlineDependencyTest.kt index 1d8448c3060..938e5306ce9 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/InlineDependencyTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/InlineDependencyTest.kt @@ -5,10 +5,15 @@ package software.amazon.smithy.rust.codegen.core.rustlang +import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldBe import io.kotest.matchers.shouldNotBe import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import kotlin.io.path.pathString internal class InlineDependencyTest { private fun makeDep(name: String) = InlineDependency(name, RustModule.private("module")) { @@ -28,17 +33,66 @@ internal class InlineDependencyTest { @Test fun `locate dependencies from the inlineable module`() { val dep = InlineDependency.idempotencyToken() - val testWriter = RustWriter.root() - testWriter.addDependency(CargoDependency.FastRand) - testWriter.withModule(dep.module.copy(rustMetadata = RustMetadata(visibility = Visibility.PUBLIC))) { - dep.renderer(this) + val testProject = TestWorkspace.testProject() + testProject.unitTest { + rustTemplate( + """ + use #{idempotency}::uuid_v4; + let res = uuid_v4(0); + assert_eq!(res, "00000000-0000-4000-8000-000000000000"); + + """, + "idempotency" to dep.asType(), + ) } - testWriter.compileAndTest( - """ - use crate::idempotency_token::uuid_v4; - let res = uuid_v4(0); - assert_eq!(res, "00000000-0000-4000-8000-000000000000"); - """, - ) + testProject.compileAndTest() + } + + @Test + fun `nested dependency modules`() { + val a = RustModule.public("a") + val b = RustModule.public("b", parent = a) + val c = RustModule.public("c", parent = b) + val type = RuntimeType.forInlineFun("forty2", c) { + rust( + """ + pub fn forty2() -> usize { 42 } + """, + ) + } + val crate = TestWorkspace.testProject() + crate.lib { + unitTest("use_nested_module") { + rustTemplate("assert_eq!(42, #{forty2}())", "forty2" to type) + } + } + crate.compileAndTest() + val generatedFiles = crate.generatedFiles().map { it.pathString } + assert(generatedFiles.contains("src/a.rs")) { generatedFiles } + assert(generatedFiles.contains("src/a/b.rs")) { generatedFiles } + assert(generatedFiles.contains("src/a/b/c.rs")) { generatedFiles } + } + + @Test + fun `prevent the creation of duplicate modules`() { + val root = RustModule.private("parent") + // create a child module with no docs + val child1 = RustModule.private("child", parent = root) + val child2 = RustModule.public("child", parent = root) + val crate = TestWorkspace.testProject() + crate.withModule(child1) { } + shouldThrow { + crate.withModule(child2) {} + } + + shouldThrow { + // can't make one with docs when the old one had no docs + crate.withModule(RustModule.private("child", documentation = "docs", parent = root)) {} + } + + // but making an identical module is fine + val identicalChild = RustModule.private("child", parent = root) + crate.withModule(identicalChild) {} + identicalChild.fullyQualifiedPath() shouldBe "crate::parent::child" } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriterTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriterTest.kt index 4c18a23c2e5..b3a8e6529c8 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriterTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriterTest.kt @@ -27,7 +27,7 @@ class RustWriterTest { fun `inner modules correctly handle dependencies`() { val sut = RustWriter.forModule("parent") val requestBuilder = RuntimeType.HttpRequestBuilder - sut.withModule(RustModule.public("inner")) { + sut.withInlineModule(RustModule.new("inner", visibility = Visibility.PUBLIC, inline = true)) { rustBlock("fn build(builder: #T)", requestBuilder) { } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt index 110e7f35a98..fcf242dcc4b 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt @@ -13,10 +13,10 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.docs import software.amazon.smithy.rust.codegen.core.rustlang.raw import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel @@ -100,20 +100,21 @@ class StructureGeneratorTest { @Test fun `generate structures with public fields`() { + val project = TestWorkspace.testProject() val provider = testSymbolProvider(model) - val writer = RustWriter.root() - writer.rust("##![allow(deprecated)]") - writer.withModule(RustModule.public("model")) { + + project.lib { rust("##![allow(deprecated)]") } + project.withModule(ModelsModule) { val innerGenerator = StructureGenerator(model, provider, this, inner) innerGenerator.render() } - writer.withModule(RustModule.public("structs")) { + project.withModule(RustModule.public("structs")) { val generator = StructureGenerator(model, provider, this, struct) generator.render() } // By putting the test in another module, it can't access the struct // fields if they are private - writer.withModule(RustModule.public("inline")) { + project.unitTest { raw("#[test]") rustBlock("fn test_public_fields()") { write( @@ -124,7 +125,7 @@ class StructureGeneratorTest { ) } } - writer.compileAndTest() + project.compileAndTest() } @Test @@ -179,15 +180,16 @@ class StructureGeneratorTest { nested2: Inner }""".asSmithyModel() val provider = testSymbolProvider(model) - val writer = RustWriter.root() - writer.docs("module docs") - Attribute.Custom("deny(missing_docs)").render(writer) - writer.withModule(RustModule.public("model")) { + val project = TestWorkspace.testProject(provider) + project.lib { + Attribute.Custom("deny(missing_docs)").render(this) + } + project.withModule(ModelsModule) { StructureGenerator(model, provider, this, model.lookup("com.test#Inner")).render() StructureGenerator(model, provider, this, model.lookup("com.test#MyStruct")).render() } - writer.compileAndTest() + project.compileAndTest() } @Test @@ -224,9 +226,9 @@ class StructureGeneratorTest { structure Qux {} """.asSmithyModel() val provider = testSymbolProvider(model) - val writer = RustWriter.root() - writer.rust("##![allow(deprecated)]") - writer.withModule(RustModule.public("model")) { + val project = TestWorkspace.testProject(provider) + project.lib { rust("##![allow(deprecated)]") } + project.withModule(ModelsModule) { StructureGenerator(model, provider, this, model.lookup("test#Foo")).render() StructureGenerator(model, provider, this, model.lookup("test#Bar")).render() StructureGenerator(model, provider, this, model.lookup("test#Baz")).render() @@ -234,7 +236,7 @@ class StructureGeneratorTest { } // turn on clippy to check the semver-compliant version of `since`. - writer.compileAndTest(clippy = true) + project.compileAndTest(runClippy = true) } @Test @@ -257,15 +259,15 @@ class StructureGeneratorTest { structure Bar {} """.asSmithyModel() val provider = testSymbolProvider(model) - val writer = RustWriter.root() - writer.rust("##![allow(deprecated)]") - writer.withModule(RustModule.public("model")) { + val project = TestWorkspace.testProject(provider) + project.lib { rust("##![allow(deprecated)]") } + project.withModule(ModelsModule) { StructureGenerator(model, provider, this, model.lookup("test#Nested")).render() StructureGenerator(model, provider, this, model.lookup("test#Foo")).render() StructureGenerator(model, provider, this, model.lookup("test#Bar")).render() } - writer.compileAndTest() + project.compileAndTest() } @Test diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGeneratorTest.kt index 043941bfe18..d2ea8b9de11 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGeneratorTest.kt @@ -8,9 +8,10 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators import io.kotest.matchers.string.shouldContain import org.junit.jupiter.api.Test import software.amazon.smithy.codegen.core.SymbolProvider -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider @@ -112,16 +113,16 @@ class UnionGeneratorTest { @deprecated union Bar { x: Integer } """.asSmithyModel() - val provider: SymbolProvider = testSymbolProvider(model) - val writer = RustWriter.root() - writer.rust("##![allow(deprecated)]") - writer.withModule(RustModule.public("model")) { + val provider = testSymbolProvider(model) + val project = TestWorkspace.testProject(provider) + project.lib { rust("##![allow(deprecated)]") } + project.withModule(ModelsModule) { UnionGenerator(model, provider, this, model.lookup("test#Nested")).render() UnionGenerator(model, provider, this, model.lookup("test#Foo")).render() UnionGenerator(model, provider, this, model.lookup("test#Bar")).render() } - writer.compileAndTest() + project.compileAndTest() } private fun generateUnion(modelSmithy: String, unionName: String = "MyUnion", unknownVariant: Boolean = true): RustWriter { diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/CombinedErrorGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/CombinedErrorGeneratorTest.kt index d84e50b5cc5..30838e82da4 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/CombinedErrorGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/CombinedErrorGeneratorTest.kt @@ -7,7 +7,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators.error import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel @@ -50,7 +50,7 @@ class CombinedErrorGeneratorTest { @Test fun `generates combined error enums`() { val project = TestWorkspace.testProject(symbolProvider) - project.withModule(RustModule.public("error")) { + project.withModule(ErrorsModule) { listOf("FooException", "ComplexError", "InvalidGreeting", "Deprecated").forEach { model.lookup("error#$it").renderWithModelBuilder(model, symbolProvider, this) } @@ -90,8 +90,6 @@ class CombinedErrorGeneratorTest { """, ) - println("file:///${project.baseDir}/src/lib.rs") - println("file:///${project.baseDir}/src/error.rs") project.compileAndTest() } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/TopLevelErrorGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/TopLevelErrorGeneratorTest.kt index d407af3fe44..a3d8fc364a2 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/TopLevelErrorGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/TopLevelErrorGeneratorTest.kt @@ -15,7 +15,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.CoreRustSettings import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator -import software.amazon.smithy.rust.codegen.core.testutil.DefaultTestPublicModules import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.generatePluginContext import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider @@ -75,7 +74,6 @@ internal class TopLevelErrorGeneratorTest { val rustCrate = RustCrate( pluginContext.fileManifest, symbolProvider, - DefaultTestPublicModules, codegenContext.settings.codegenConfig, ) diff --git a/codegen-server-test/build.gradle.kts b/codegen-server-test/build.gradle.kts index a49bbcce507..82187c2db61 100644 --- a/codegen-server-test/build.gradle.kts +++ b/codegen-server-test/build.gradle.kts @@ -69,6 +69,11 @@ val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> "aws.protocoltests.restjson.validation#RestJsonValidation", "rest_json_validation", extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """, ), + CodegenTest( + "aws.protocoltests.extras.restjson.validation#MalformedRangeValidation", "malformed_range_extras", + extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """, + imports = listOf("$commonModels/malformed-range-extras.smithy"), + ), CodegenTest("aws.protocoltests.json10#JsonRpc10", "json_rpc10"), CodegenTest( "aws.protocoltests.json#JsonProtocol", diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt index 0eac9dd1961..c9d73999f33 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt @@ -23,7 +23,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerEnumGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerServiceGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerStructureGenerator -import software.amazon.smithy.rust.codegen.server.smithy.DefaultServerPublicModules import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenVisitor import software.amazon.smithy.rust.codegen.server.smithy.ServerSymbolProviders @@ -98,7 +97,7 @@ class PythonServerCodegenVisitor( ) // Override `rustCrate` which carries the symbolProvider. - rustCrate = RustCrate(context.fileManifest, codegenContext.symbolProvider, DefaultServerPublicModules, settings.codegenConfig) + rustCrate = RustCrate(context.fileManifest, codegenContext.symbolProvider, settings.codegenConfig) // Override `protocolGenerator` which carries the symbolProvider. protocolGenerator = protocolGeneratorFactory.buildProtocolGenerator(codegenContext) } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index 4431b5f0a45..63fd61e1c9f 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -13,9 +13,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.Errors -import software.amazon.smithy.rust.codegen.core.smithy.Inputs -import software.amazon.smithy.rust.codegen.core.smithy.Outputs +import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule +import software.amazon.smithy.rust.codegen.core.smithy.InputsModule +import software.amazon.smithy.rust.codegen.core.smithy.OutputsModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.inputShape @@ -337,12 +337,12 @@ class PythonApplicationGenerator( ) writer.rust( """ - /// from $libName import ${Inputs.namespace} - /// from $libName import ${Outputs.namespace} + /// from $libName import ${InputsModule.name} + /// from $libName import ${OutputsModule.name} """.trimIndent(), ) if (operations.any { it.errors.isNotEmpty() }) { - writer.rust("""/// from $libName import ${Errors.namespace}""".trimIndent()) + writer.rust("""/// from $libName import ${ErrorsModule.name}""".trimIndent()) } writer.rust( """ @@ -396,8 +396,8 @@ class PythonApplicationGenerator( private fun OperationShape.signature(): String { val inputSymbol = symbolProvider.toSymbol(inputShape(model)) val outputSymbol = symbolProvider.toSymbol(outputShape(model)) - val inputT = "${Inputs.namespace}::${inputSymbol.name}" - val outputT = "${Outputs.namespace}::${outputSymbol.name}" + val inputT = "${InputsModule.name}::${inputSymbol.name}" + val outputT = "${OutputsModule.name}::${outputSymbol.name}" val operationName = symbolProvider.toSymbol(this).name.toSnakeCase() return "@app.$operationName\n/// def $operationName(input: $inputT, ctx: Context) -> $outputT" } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt index 92ff5faf77e..240d5830651 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.ServiceShape @@ -16,7 +17,7 @@ import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.traits.LengthTrait import software.amazon.smithy.rust.codegen.core.rustlang.RustType -import software.amazon.smithy.rust.codegen.core.smithy.Models +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.contextName @@ -58,7 +59,7 @@ class ConstrainedShapeSymbolProvider( check(shape is MapShape) val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase()) - return symbolBuilder(shape, rustType).locatedIn(Models).build() + return symbolBuilder(shape, rustType).locatedIn(ModelsModule).build() } override fun toSymbol(shape: Shape): Symbol { @@ -95,10 +96,10 @@ class ConstrainedShapeSymbolProvider( symbolBuilder(shape, RustType.Vec(inner.rustType())).addReference(inner).build() } } - is StringShape -> { + is StringShape, is IntegerShape -> { if (shape.isDirectlyConstrained(base)) { val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase()) - symbolBuilder(shape, rustType).locatedIn(Models).build() + symbolBuilder(shape, rustType).locatedIn(ModelsModule).build() } else { base.toSymbol(shape) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt index 119db964524..264edf545be 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt @@ -8,18 +8,23 @@ package software.amazon.smithy.rust.codegen.server.smithy import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.core.rustlang.RustType -import software.amazon.smithy.rust.codegen.core.smithy.Models +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.contextName +import software.amazon.smithy.rust.codegen.core.smithy.locatedIn +import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol @@ -60,27 +65,32 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilde class ConstraintViolationSymbolProvider( private val base: RustSymbolProvider, private val model: Model, - private val serviceShape: ServiceShape, private val publicConstrainedTypes: Boolean, + private val serviceShape: ServiceShape, ) : WrappingSymbolProvider(base) { private val constraintViolationName = "ConstraintViolation" + private val visibility = when (publicConstrainedTypes) { + true -> Visibility.PUBLIC + false -> Visibility.PUBCRATE + } + + private fun Shape.shapeModule() = RustModule.new( + // need to use the context name so we get the correct name for maps + name = RustReservedWords.escapeIfNeeded(this.contextName(serviceShape)).toSnakeCase(), + visibility = visibility, + parent = ModelsModule, + inline = true, + ) private fun constraintViolationSymbolForCollectionOrMapOrUnionShape(shape: Shape): Symbol { check(shape is CollectionShape || shape is MapShape || shape is UnionShape) - val symbol = base.toSymbol(shape) - val constraintViolationNamespace = - "${symbol.namespace.let { it.ifEmpty { "crate::${Models.namespace}" } }}::${ - RustReservedWords.escapeIfNeeded( - shape.contextName(serviceShape).toSnakeCase(), - ) - }" - val rustType = RustType.Opaque(constraintViolationName, constraintViolationNamespace) + val module = shape.shapeModule() + val rustType = RustType.Opaque(constraintViolationName, module.fullyQualifiedPath()) return Symbol.builder() .rustType(rustType) .name(rustType.name) - .namespace(rustType.namespace, "::") - .definitionFile(symbol.definitionFile) + .locatedIn(module) .build() } @@ -91,32 +101,28 @@ class ConstraintViolationSymbolProvider( is MapShape, is CollectionShape, is UnionShape -> { constraintViolationSymbolForCollectionOrMapOrUnionShape(shape) } + is StructureShape -> { val builderSymbol = shape.serverBuilderSymbol(base, pubCrate = !publicConstrainedTypes) - val namespace = builderSymbol.namespace - val rustType = RustType.Opaque(constraintViolationName, namespace) + val module = builderSymbol.module() + val rustType = RustType.Opaque(constraintViolationName, module.fullyQualifiedPath()) Symbol.builder() .rustType(rustType) .name(rustType.name) - .namespace(rustType.namespace, "::") - .definitionFile(builderSymbol.definitionFile) + .locatedIn(module) .build() } - is StringShape -> { - val namespace = "crate::${Models.namespace}::${ - RustReservedWords.escapeIfNeeded( - shape.contextName(serviceShape).toSnakeCase(), - ) - }" - val rustType = RustType.Opaque(constraintViolationName, namespace) + is StringShape, is IntegerShape -> { + val module = shape.shapeModule() + val rustType = RustType.Opaque(constraintViolationName, module.fullyQualifiedPath()) Symbol.builder() .rustType(rustType) .name(rustType.name) - .namespace(rustType.namespace, "::") - .definitionFile(Models.filename) + .locatedIn(module) .build() } + else -> TODO("Constraint traits on other shapes not implemented yet: $shape") } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt index 82102be18a0..592fe742827 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.neighbor.Walker import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.Shape @@ -21,6 +22,7 @@ import software.amazon.smithy.model.traits.LengthTrait import software.amazon.smithy.model.traits.PatternTrait import software.amazon.smithy.model.traits.RangeTrait import software.amazon.smithy.model.traits.RequiredTrait +import software.amazon.smithy.model.traits.Trait import software.amazon.smithy.model.traits.UniqueItemsTrait import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE @@ -42,6 +44,8 @@ fun Shape.hasConstraintTrait() = hasTrait() || hasTrait() +val supportedStringConstraintTraits: List> = listOf(LengthTrait::class.java, PatternTrait::class.java) + /** * We say a shape is _directly_ constrained if: * @@ -65,8 +69,10 @@ fun Shape.isDirectlyConstrained(symbolProvider: SymbolProvider): Boolean = when // `required`, so we can't use `member.isOptional` here. this.members().map { symbolProvider.toSymbol(it) }.any { !it.isOptional() } } + is MapShape -> this.hasTrait() - is StringShape -> this.hasTrait() || this.hasTrait() + is StringShape -> this.hasTrait() || supportedStringConstraintTraits.any { this.hasTrait(it) } + is IntegerShape -> this.hasTrait() else -> false } @@ -91,7 +97,8 @@ fun MemberShape.targetCanReachConstrainedShape(model: Model, symbolProvider: Sym fun Shape.hasPublicConstrainedWrapperTupleType(model: Model, publicConstrainedTypes: Boolean): Boolean = when (this) { is MapShape -> publicConstrainedTypes && this.hasTrait() - is StringShape -> !this.hasTrait() && (publicConstrainedTypes && this.hasTrait()) + is StringShape -> !this.hasTrait() && (publicConstrainedTypes && supportedStringConstraintTraits.any(this::hasTrait)) + is IntegerShape -> publicConstrainedTypes && this.hasTrait() is MemberShape -> model.expectShape(this.target).hasPublicConstrainedWrapperTupleType(model, publicConstrainedTypes) else -> false } @@ -125,7 +132,9 @@ fun Shape.typeNameContainsNonPublicType( publicConstrainedTypes: Boolean, ): Boolean = !publicConstrainedTypes && when (this) { is SimpleShape -> wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled(model) - is MemberShape -> model.expectShape(this.target).typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes) + is MemberShape -> model.expectShape(this.target) + .typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes) + is CollectionShape -> this.canReachConstrainedShape(model, symbolProvider) is MapShape -> this.canReachConstrainedShape(model, symbolProvider) is StructureShape, is UnionShape -> false diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PatternTraitValidationErrorMessage.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PatternTraitValidationErrorMessage.kt new file mode 100644 index 00000000000..8bb3cb648e1 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PatternTraitValidationErrorMessage.kt @@ -0,0 +1,12 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.model.traits.PatternTrait + +@Suppress("UnusedReceiverParameter") +fun PatternTrait.validationErrorMessage(): String = + "Value {} at '{}' failed to satisfy constraint: Member must satisfy regular expression pattern: {}" diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt index e63e18c7ac4..bb818eaf078 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt @@ -16,13 +16,16 @@ import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.SimpleShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.core.rustlang.RustType -import software.amazon.smithy.rust.codegen.core.smithy.Constrained +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.smithy.ConstrainedModule import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.handleOptionality import software.amazon.smithy.rust.codegen.core.smithy.handleRustBoxing +import software.amazon.smithy.rust.codegen.core.smithy.locatedIn import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.toPascalCase @@ -68,13 +71,17 @@ class PubCrateConstrainedShapeSymbolProvider( check(shape is CollectionShape || shape is MapShape) val name = constrainedTypeNameForCollectionOrMapShape(shape, serviceShape) - val namespace = "crate::${Constrained.namespace}::${RustReservedWords.escapeIfNeeded(name.toSnakeCase())}" - val rustType = RustType.Opaque(name, namespace) + val module = RustModule.new( + RustReservedWords.escapeIfNeeded(name.toSnakeCase()), + visibility = Visibility.PUBCRATE, + parent = ConstrainedModule, + inline = true, + ) + val rustType = RustType.Opaque(name, module.fullyQualifiedPath()) return Symbol.builder() .rustType(rustType) .name(rustType.name) - .namespace(rustType.namespace, "::") - .definitionFile(Constrained.filename) + .locatedIn(module) .build() } @@ -88,6 +95,7 @@ class PubCrateConstrainedShapeSymbolProvider( is CollectionShape, is MapShape -> { constrainedSymbolForCollectionOrMapShape(shape) } + is MemberShape -> { require(model.expectShape(shape.container).isStructureShape) { "This arm is only exercised by `ServerBuilderGenerator`" @@ -101,13 +109,20 @@ class PubCrateConstrainedShapeSymbolProvider( } else { val targetSymbol = this.toSymbol(targetShape) // Handle boxing first so we end up with `Option>`, not `Box>`. - handleOptionality(handleRustBoxing(targetSymbol, shape), shape, nullableIndex, base.config().nullabilityCheckMode) + handleOptionality( + handleRustBoxing(targetSymbol, shape), + shape, + nullableIndex, + base.config().nullabilityCheckMode, + ) } } + is StructureShape, is UnionShape -> { // Structure shapes and union shapes always generate a [RustType.Opaque] constrained type. base.toSymbol(shape) } + else -> { check(shape is SimpleShape) // The rest of the shape types are simple shapes, which are impossible to be transitively but not diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstraintViolationSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstraintViolationSymbolProvider.kt index 05a8d635a4b..4e7b5ab6c72 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstraintViolationSymbolProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstraintViolationSymbolProvider.kt @@ -8,8 +8,11 @@ package software.amazon.smithy.rust.codegen.server.smithy import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.locatedIn +import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.smithy.rustType /** @@ -28,10 +31,11 @@ class PubCrateConstraintViolationSymbolProvider( return baseSymbol } val baseRustType = baseSymbol.rustType() - val newNamespace = baseSymbol.namespace + "_internal" + val oldModule = baseSymbol.module() as RustModule.LeafModule + val newModule = oldModule.copy(name = oldModule.name + "_internal") return baseSymbol.toBuilder() - .rustType(RustType.Opaque(baseRustType.name, newNamespace)) - .namespace(newNamespace, baseSymbol.namespaceDelimiter) + .rustType(RustType.Opaque(baseRustType.name, newModule.fullyQualifiedPath())) + .locatedIn(newModule) .build() } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RangeTraitValidationErrorMessage.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RangeTraitValidationErrorMessage.kt new file mode 100644 index 00000000000..5512da84704 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RangeTraitValidationErrorMessage.kt @@ -0,0 +1,21 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.model.traits.RangeTrait + +fun RangeTrait.validationErrorMessage(): String { + val beginning = "Value {} at '{}' failed to satisfy constraint: Member must be " + val ending = if (this.min.isPresent && this.max.isPresent) { + "between ${this.min.get()} and ${this.max.get()}, inclusive" + } else if (this.min.isPresent) ( + "greater than or equal to ${this.min.get()}" + ) else { + check(this.max.isPresent) + "less than or equal to ${this.max.get()}" + } + return "$beginning$ending" +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt index 52da4fedb06..ec429d7b365 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt @@ -8,6 +8,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.CratesIo import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig /** @@ -25,6 +26,7 @@ object ServerCargoDependency { val PinProjectLite: CargoDependency = CargoDependency("pin-project-lite", CratesIo("0.2")) val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4")) val TokioDev: CargoDependency = CargoDependency("tokio", CratesIo("1.8.4"), scope = DependencyScope.Dev) + val Regex: CargoDependency = CargoDependency("regex", CratesIo("1.5.5")) fun SmithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http-server") fun SmithyTypes(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("types") @@ -44,7 +46,8 @@ object ServerCargoDependency { object ServerInlineDependency { fun serverOperationHandler(runtimeConfig: RuntimeConfig): InlineDependency = InlineDependency.forRustFile( - "server_operation_handler_trait", + RustModule.private("server_operation_handler_trait"), + "/inlineable/src/server_operation_handler_trait.rs", ServerCargoDependency.SmithyHttpServer(runtimeConfig), CargoDependency.Http, ServerCargoDependency.PinProjectLite, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index 93d860cce55..8f094e80c8e 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -11,6 +11,7 @@ import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.neighbor.Walker import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.ServiceShape @@ -24,16 +25,15 @@ import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.LengthTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.client.smithy.customize.RustCodegenDecorator -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.Constrained +import software.amazon.smithy.rust.codegen.core.smithy.ConstrainedModule import software.amazon.smithy.rust.codegen.core.smithy.CoreRustSettings import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig -import software.amazon.smithy.rust.codegen.core.smithy.Unconstrained +import software.amazon.smithy.rust.codegen.core.smithy.UnconstrainedModule import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock @@ -44,6 +44,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveSha import software.amazon.smithy.rust.codegen.core.util.CommandFailed import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.runCommand +import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstrainedIntegerGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstrainedMapGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstrainedStringGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstrainedTraitForEnumGenerator @@ -67,13 +68,6 @@ import software.amazon.smithy.rust.codegen.server.smithy.transformers.RemoveEbsM import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger import java.util.logging.Logger -val DefaultServerPublicModules = setOf( - RustModule.Error, - RustModule.Model, - RustModule.Input, - RustModule.Output, -).associateBy { it.name } - /** * Entrypoint for server-side code generation. This class will walk the in-memory model and * generate all the needed types by calling the accept() function on the available shapes. @@ -92,10 +86,6 @@ open class ServerCodegenVisitor( protected var codegenContext: ServerCodegenContext protected var protocolGeneratorFactory: ProtocolGeneratorFactory protected var protocolGenerator: ServerProtocolGenerator - private val unconstrainedModule = - RustModule.private(Unconstrained.namespace, "Unconstrained types for constrained shapes.") - private val constrainedModule = - RustModule.private(Constrained.namespace, "Constrained types for constrained shapes.") init { val symbolVisitorConfig = @@ -139,7 +129,7 @@ open class ServerCodegenVisitor( serverSymbolProviders.pubCrateConstrainedShapeSymbolProvider, ) - rustCrate = RustCrate(context.fileManifest, codegenContext.symbolProvider, DefaultServerPublicModules, settings.codegenConfig) + rustCrate = RustCrate(context.fileManifest, codegenContext.symbolProvider, settings.codegenConfig) protocolGenerator = protocolGeneratorFactory.buildProtocolGenerator(codegenContext) } @@ -295,7 +285,7 @@ open class ServerCodegenVisitor( ) ) { logger.info("[rust-server-codegen] Generating an unconstrained type for collection shape $shape") - rustCrate.withModule(unconstrainedModule) unconstrainedModuleWriter@{ + rustCrate.withModule(UnconstrainedModule) unconstrainedModuleWriter@{ rustCrate.withModule(ModelsModule) modelsModuleWriter@{ UnconstrainedCollectionGenerator( codegenContext, @@ -307,7 +297,7 @@ open class ServerCodegenVisitor( } logger.info("[rust-server-codegen] Generating a constrained type for collection shape $shape") - rustCrate.withModule(constrainedModule) { + rustCrate.withModule(ConstrainedModule) { PubCrateConstrainedCollectionGenerator(codegenContext, this, shape).render() } } @@ -321,13 +311,13 @@ open class ServerCodegenVisitor( ) if (renderUnconstrainedMap) { logger.info("[rust-server-codegen] Generating an unconstrained type for map $shape") - rustCrate.withModule(unconstrainedModule) { + rustCrate.withModule(UnconstrainedModule) { UnconstrainedMapGenerator(codegenContext, this, shape).render() } if (!shape.isDirectlyConstrained(codegenContext.symbolProvider)) { logger.info("[rust-server-codegen] Generating a constrained type for map $shape") - rustCrate.withModule(constrainedModule) { + rustCrate.withModule(ConstrainedModule) { PubCrateConstrainedMapGenerator(codegenContext, this, shape).render() } } @@ -363,6 +353,15 @@ open class ServerCodegenVisitor( stringShape(shape, ::serverEnumGeneratorFactory) } + override fun integerShape(shape: IntegerShape) { + if (shape.isDirectlyConstrained(codegenContext.symbolProvider)) { + logger.info("[rust-server-codegen] Generating a constrained integer $shape") + rustCrate.withModule(ModelsModule) { + ConstrainedIntegerGenerator(codegenContext, this, shape).render() + } + } + } + protected fun stringShape( shape: StringShape, enumShapeGeneratorFactory: (codegenContext: ServerCodegenContext, writer: RustWriter, shape: StringShape) -> ServerEnumGenerator, @@ -411,7 +410,7 @@ open class ServerCodegenVisitor( ) ) { logger.info("[rust-server-codegen] Generating an unconstrained type for union shape $shape") - rustCrate.withModule(unconstrainedModule) unconstrainedModuleWriter@{ + rustCrate.withModule(UnconstrainedModule) unconstrainedModuleWriter@{ rustCrate.withModule(ModelsModule) modelsModuleWriter@{ UnconstrainedUnionGenerator( codegenContext, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt index e2b77c90fd7..0e368d85175 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt @@ -46,7 +46,7 @@ class ServerSymbolProviders private constructor( symbolVisitorConfig, false, ), - model, service, publicConstrainedTypes, + model, publicConstrainedTypes, service, ), pubCrateConstrainedShapeSymbolProvider = PubCrateConstrainedShapeSymbolProvider( baseSymbolProvider, @@ -56,8 +56,8 @@ class ServerSymbolProviders private constructor( constraintViolationSymbolProvider = ConstraintViolationSymbolProvider( baseSymbolProvider, model, - service, publicConstrainedTypes, + service, ), ) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt index 9fa2182e6b2..5beb183c816 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt @@ -16,14 +16,18 @@ import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.smithy.Default import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.Unconstrained +import software.amazon.smithy.rust.codegen.core.smithy.UnconstrainedModule import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.contextName import software.amazon.smithy.rust.codegen.core.smithy.handleOptionality import software.amazon.smithy.rust.codegen.core.smithy.handleRustBoxing +import software.amazon.smithy.rust.codegen.core.smithy.locatedIn import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.smithy.setDefault import software.amazon.smithy.rust.codegen.core.smithy.symbolBuilder @@ -75,22 +79,39 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilde class UnconstrainedShapeSymbolProvider( private val base: RustSymbolProvider, private val model: Model, - private val serviceShape: ServiceShape, private val publicConstrainedTypes: Boolean, + private val serviceShape: ServiceShape, ) : WrappingSymbolProvider(base) { private val nullableIndex = NullableIndex.of(model) + /** + * Unconstrained type names are always suffixed with `Unconstrained` for clarity, even though we could dispense with it + * given that they all live inside the `unconstrained` module, so they don't collide with the constrained types. + */ + private fun unconstrainedTypeNameForCollectionOrMapOrUnionShape(shape: Shape): String { + check(shape is CollectionShape || shape is MapShape || shape is UnionShape) + // Normally, one could use the base symbol provider's name. However, in this case, the name will be `Vec` or + // `HashMap` because the symbol provider _does not_ newtype the shapes. However, for unconstrained shapes, + // we need to introduce a newtype that preserves the original name of the shape from smithy. To handle that, + // we load the name of the shape directly from the model prior to add `Unconstrained`. + return RustReservedWords.escapeIfNeeded(shape.contextName(serviceShape).toPascalCase() + "Unconstrained") + } + private fun unconstrainedSymbolForCollectionOrMapOrUnionShape(shape: Shape): Symbol { check(shape is CollectionShape || shape is MapShape || shape is UnionShape) - val name = unconstrainedTypeNameForCollectionOrMapOrUnionShape(shape, serviceShape) - val namespace = "crate::${Unconstrained.namespace}::${RustReservedWords.escapeIfNeeded(name.toSnakeCase())}" - val rustType = RustType.Opaque(name, namespace) + val name = unconstrainedTypeNameForCollectionOrMapOrUnionShape(shape) + val module = RustModule.new( + RustReservedWords.escapeIfNeeded(name.toSnakeCase()), + visibility = Visibility.PUBCRATE, + parent = UnconstrainedModule, + inline = true, + ) + val rustType = RustType.Opaque(name, module.fullyQualifiedPath()) return Symbol.builder() .rustType(rustType) .name(rustType.name) - .namespace(rustType.namespace, "::") - .definitionFile(Unconstrained.filename) + .locatedIn(module) .build() } @@ -103,6 +124,7 @@ class UnconstrainedShapeSymbolProvider( base.toSymbol(shape) } } + is MapShape -> { if (shape.canReachConstrainedShape(model, base)) { unconstrainedSymbolForCollectionOrMapOrUnionShape(shape) @@ -110,6 +132,7 @@ class UnconstrainedShapeSymbolProvider( base.toSymbol(shape) } } + is StructureShape -> { if (shape.canReachConstrainedShape(model, base)) { shape.serverBuilderSymbol(base, !publicConstrainedTypes) @@ -117,6 +140,7 @@ class UnconstrainedShapeSymbolProvider( base.toSymbol(shape) } } + is UnionShape -> { if (shape.canReachConstrainedShape(model, base)) { unconstrainedSymbolForCollectionOrMapOrUnionShape(shape) @@ -124,6 +148,7 @@ class UnconstrainedShapeSymbolProvider( base.toSymbol(shape) } } + is MemberShape -> { // There are only two cases where we use this symbol provider on a member shape. // @@ -138,13 +163,19 @@ class UnconstrainedShapeSymbolProvider( val targetShape = model.expectShape(shape.target) val targetSymbol = this.toSymbol(targetShape) // Handle boxing first so we end up with `Option>`, not `Box>`. - handleOptionality(handleRustBoxing(targetSymbol, shape), shape, nullableIndex, base.config().nullabilityCheckMode) + handleOptionality( + handleRustBoxing(targetSymbol, shape), + shape, + nullableIndex, + base.config().nullabilityCheckMode, + ) } else { base.toSymbol(shape) } // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Constraint traits on member shapes are not // implemented yet. } + is StringShape -> { if (shape.canReachConstrainedShape(model, base)) { symbolBuilder(shape, RustType.String).setDefault(Default.RustDefault).build() @@ -152,15 +183,7 @@ class UnconstrainedShapeSymbolProvider( base.toSymbol(shape) } } + else -> base.toSymbol(shape) } } - -/** - * Unconstrained type names are always suffixed with `Unconstrained` for clarity, even though we could dispense with it - * given that they all live inside the `unconstrained` module, so they don't collide with the constrained types. - */ -fun unconstrainedTypeNameForCollectionOrMapOrUnionShape(shape: Shape, serviceShape: ServiceShape): String { - check(shape is CollectionShape || shape is MapShape || shape is UnionShape) - return "${shape.id.getName(serviceShape).toPascalCase()}Unconstrained" -} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt index 871a8d364c5..76f5fd1d258 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt @@ -10,13 +10,13 @@ import software.amazon.smithy.model.neighbor.Walker import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.EnumShape +import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.SetShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.LengthTrait @@ -40,13 +40,17 @@ private sealed class UnsupportedConstraintMessageKind { """ $intro This is not supported in the smithy-rs server SDK. - ${ if (willSupport) "It will be supported in the future." else "" } + ${if (willSupport) "It will be supported in the future." else ""} See the tracking issue ($trackingIssue). - If you want to go ahead and generate the server SDK ignoring unsupported constraint traits, set the key `ignoreUnsupportedConstraintTraits` + If you want to go ahead and generate the server SDK ignoring unsupported constraint traits, set the key `ignoreUnsupportedConstraints` inside the `runtimeConfig.codegenConfig` JSON object in your `smithy-build.json` to `true`. """.trimIndent().replace("\n", " ") - fun buildMessageShapeHasUnsupportedConstraintTrait(shape: Shape, constraintTrait: Trait, trackingIssue: String) = + fun buildMessageShapeHasUnsupportedConstraintTrait( + shape: Shape, + constraintTrait: Trait, + trackingIssue: String, + ) = buildMessage( "The ${shape.type} shape `${shape.id}` has the constraint trait `${constraintTrait.toShapeId()}` attached.", willSupport = true, @@ -60,6 +64,7 @@ private sealed class UnsupportedConstraintMessageKind { level, buildMessageShapeHasUnsupportedConstraintTrait(shape, constraintTrait, constraintTraitsUberIssue), ) + is UnsupportedConstraintOnShapeReachableViaAnEventStream -> LogMessage( level, buildMessage( @@ -71,6 +76,7 @@ private sealed class UnsupportedConstraintMessageKind { "https://github.com/awslabs/smithy/issues/1388", ), ) + is UnsupportedLengthTraitOnStreamingBlobShape -> LogMessage( level, buildMessage( @@ -82,18 +88,17 @@ private sealed class UnsupportedConstraintMessageKind { "https://github.com/awslabs/smithy/issues/1389", ), ) + is UnsupportedLengthTraitOnCollectionOrOnBlobShape -> LogMessage( level, buildMessageShapeHasUnsupportedConstraintTrait(shape, lengthTrait, constraintTraitsUberIssue), ) - is UnsupportedPatternTraitOnStringShape -> LogMessage( - level, - buildMessageShapeHasUnsupportedConstraintTrait(shape, patternTrait, constraintTraitsUberIssue), - ) + is UnsupportedRangeTraitOnShape -> LogMessage( level, buildMessageShapeHasUnsupportedConstraintTrait(shape, rangeTrait, constraintTraitsUberIssue), ) + is UnsupportedUniqueItemsTraitOnShape -> LogMessage( level, buildMessageShapeHasUnsupportedConstraintTrait(shape, uniqueItemsTrait, constraintTraitsUberIssue), @@ -101,14 +106,28 @@ private sealed class UnsupportedConstraintMessageKind { } } } + private data class OperationWithConstrainedInputWithoutValidationException(val shape: OperationShape) -private data class UnsupportedConstraintOnMemberShape(val shape: MemberShape, val constraintTrait: Trait) : UnsupportedConstraintMessageKind() -private data class UnsupportedConstraintOnShapeReachableViaAnEventStream(val shape: Shape, val constraintTrait: Trait) : UnsupportedConstraintMessageKind() -private data class UnsupportedLengthTraitOnStreamingBlobShape(val shape: BlobShape, val lengthTrait: LengthTrait, val streamingTrait: StreamingTrait) : UnsupportedConstraintMessageKind() -private data class UnsupportedLengthTraitOnCollectionOrOnBlobShape(val shape: Shape, val lengthTrait: LengthTrait) : UnsupportedConstraintMessageKind() -private data class UnsupportedPatternTraitOnStringShape(val shape: Shape, val patternTrait: PatternTrait) : UnsupportedConstraintMessageKind() -private data class UnsupportedRangeTraitOnShape(val shape: Shape, val rangeTrait: RangeTrait) : UnsupportedConstraintMessageKind() -private data class UnsupportedUniqueItemsTraitOnShape(val shape: Shape, val uniqueItemsTrait: UniqueItemsTrait) : UnsupportedConstraintMessageKind() +private data class UnsupportedConstraintOnMemberShape(val shape: MemberShape, val constraintTrait: Trait) : + UnsupportedConstraintMessageKind() + +private data class UnsupportedConstraintOnShapeReachableViaAnEventStream(val shape: Shape, val constraintTrait: Trait) : + UnsupportedConstraintMessageKind() + +private data class UnsupportedLengthTraitOnStreamingBlobShape( + val shape: BlobShape, + val lengthTrait: LengthTrait, + val streamingTrait: StreamingTrait, +) : UnsupportedConstraintMessageKind() + +private data class UnsupportedLengthTraitOnCollectionOrOnBlobShape(val shape: Shape, val lengthTrait: LengthTrait) : + UnsupportedConstraintMessageKind() + +private data class UnsupportedRangeTraitOnShape(val shape: Shape, val rangeTrait: RangeTrait) : + UnsupportedConstraintMessageKind() + +private data class UnsupportedUniqueItemsTraitOnShape(val shape: Shape, val uniqueItemsTrait: UniqueItemsTrait) : + UnsupportedConstraintMessageKind() data class LogMessage(val level: Level, val message: String) data class ValidationResult(val shouldAbort: Boolean, val messages: List) @@ -123,7 +142,10 @@ private val allConstraintTraits = setOf( ) private val unsupportedConstraintsOnMemberShapes = allConstraintTraits - RequiredTrait::class.java -fun validateOperationsWithConstrainedInputHaveValidationExceptionAttached(model: Model, service: ServiceShape): ValidationResult { +fun validateOperationsWithConstrainedInputHaveValidationExceptionAttached( + model: Model, + service: ServiceShape, +): ValidationResult { // Traverse the model and error out if an operation uses constrained input, but it does not have // `ValidationException` attached in `errors`. https://github.com/awslabs/smithy-rs/pull/1199#discussion_r809424783 // TODO(https://github.com/awslabs/smithy-rs/issues/1401): This check will go away once we add support for @@ -146,20 +168,20 @@ fun validateOperationsWithConstrainedInputHaveValidationExceptionAttached(model: LogMessage( Level.SEVERE, """ - Operation ${it.shape.id} takes in input that is constrained - (https://awslabs.github.io/smithy/2.0/spec/constraint-traits.html), and as such can fail with a validation + Operation ${it.shape.id} takes in input that is constrained + (https://awslabs.github.io/smithy/2.0/spec/constraint-traits.html), and as such can fail with a validation exception. You must model this behavior in the operation shape in your model file. """.trimIndent().replace("\n", "") + """ - - ```smithy - use smithy.framework#ValidationException - - operation ${it.shape.id.name} { - ... - errors: [..., ValidationException] // <-- Add this. - } - ``` + + ```smithy + use smithy.framework#ValidationException + + operation ${it.shape.id.name} { + ... + errors: [..., ValidationException] // <-- Add this. + } + ``` """.trimIndent(), ) } @@ -167,7 +189,11 @@ fun validateOperationsWithConstrainedInputHaveValidationExceptionAttached(model: return ValidationResult(shouldAbort = messages.any { it.level == Level.SEVERE }, messages) } -fun validateUnsupportedConstraints(model: Model, service: ServiceShape, codegenConfig: ServerCodegenConfig): ValidationResult { +fun validateUnsupportedConstraints( + model: Model, + service: ServiceShape, + codegenConfig: ServerCodegenConfig, +): ValidationResult { // Traverse the model and error out if: val walker = Walker(model) @@ -214,32 +240,28 @@ fun validateUnsupportedConstraints(model: Model, service: ServiceShape, codegenC .map { UnsupportedLengthTraitOnCollectionOrOnBlobShape(it, it.expectTrait()) } .toSet() - // 5. Pattern trait on string shapes is used. It has not been implemented yet. - // TODO(https://github.com/awslabs/smithy-rs/issues/1401) - val unsupportedPatternTraitOnStringShapeSet = walker - .walkShapes(service) - .asSequence() - .filterIsInstance() - .filterMapShapesToTraits(setOf(PatternTrait::class.java)) - .map { (shape, patternTrait) -> UnsupportedPatternTraitOnStringShape(shape, patternTrait as PatternTrait) } - .toSet() - - // 6. Range trait on any shape is used. It has not been implemented yet. + // 5. Range trait used on a non-integer shape. It has not been implemented yet. // TODO(https://github.com/awslabs/smithy-rs/issues/1401) val unsupportedRangeTraitOnShapeSet = walker .walkShapes(service) .asSequence() + .filterNot { it is IntegerShape } .filterMapShapesToTraits(setOf(RangeTrait::class.java)) .map { (shape, rangeTrait) -> UnsupportedRangeTraitOnShape(shape, rangeTrait as RangeTrait) } .toSet() - // 7. UniqueItems trait on any shape is used. It has not been implemented yet. + // 6. UniqueItems trait on any shape is used. It has not been implemented yet. // TODO(https://github.com/awslabs/smithy-rs/issues/1401) val unsupportedUniqueItemsTraitOnShapeSet = walker .walkShapes(service) .asSequence() .filterMapShapesToTraits(setOf(UniqueItemsTrait::class.java)) - .map { (shape, uniqueItemsTrait) -> UnsupportedUniqueItemsTraitOnShape(shape, uniqueItemsTrait as UniqueItemsTrait) } + .map { (shape, uniqueItemsTrait) -> + UnsupportedUniqueItemsTraitOnShape( + shape, + uniqueItemsTrait as UniqueItemsTrait, + ) + } .toSet() val messages = @@ -247,7 +269,6 @@ fun validateUnsupportedConstraints(model: Model, service: ServiceShape, codegenC unsupportedLengthTraitOnStreamingBlobShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + unsupportedConstraintOnShapeReachableViaAnEventStreamSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + unsupportedLengthTraitOnCollectionOrOnBlobShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + - unsupportedPatternTraitOnStringShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + unsupportedRangeTraitOnShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + unsupportedUniqueItemsTraitOnShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapJsonCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapJsonCustomization.kt index 820f5bc8b7f..9fad74c0446 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapJsonCustomization.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapJsonCustomization.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerSection +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType @@ -26,11 +27,8 @@ class BeforeIteratingOverMapJsonCustomization(private val codegenContext: Server codegenContext.settings.codegenConfig.publicConstrainedTypes, ) ) { - // Note that this particular implementation just so happens to work because when the customization - // is invoked in the JSON serializer, the value expression is guaranteed to be a variable binding name. - // If the expression in the future were to be more complex, we wouldn't be able to write the left-hand - // side of this assignment. - rust("""let ${section.valueExpression.name} = &${section.valueExpression.name}.0;""") + section.context.valueExpression = + ValueExpression.Reference("&${section.context.valueExpression.name}.0") } } else -> emptySection diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberJsonCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberJsonCustomization.kt new file mode 100644 index 00000000000..0308a9dd5f7 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberJsonCustomization.kt @@ -0,0 +1,38 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerSection +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType + +/** + * A customization to, just before we serialize a _constrained_ shape in a JSON serializer, unwrap the wrapper + * newtype and take a shared reference to the actual unconstrained value within it. + */ +class BeforeSerializingMemberJsonCustomization(private val codegenContext: ServerCodegenContext) : JsonSerializerCustomization() { + override fun section(section: JsonSerializerSection): Writable = when (section) { + is JsonSerializerSection.BeforeSerializingNonNullMember -> writable { + if (workingWithPublicConstrainedWrapperTupleType( + section.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + if (section.shape is IntegerShape) { + section.context.valueExpression = + ValueExpression.Reference("&${section.context.valueExpression.name}.0") + } + } + } + else -> emptySection + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedIntegerGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedIntegerGenerator.kt new file mode 100644 index 00000000000..10c5e253960 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedIntegerGenerator.kt @@ -0,0 +1,216 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.model.traits.RangeTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.docs +import software.amazon.smithy.rust.codegen.core.rustlang.documentShape +import software.amazon.smithy.rust.codegen.core.rustlang.render +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained +import software.amazon.smithy.rust.codegen.core.smithy.module +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput +import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage + +/** + * [ConstrainedIntegerGenerator] generates a wrapper newtype holding a constrained `i32`. + * This type can be built from unconstrained values, yielding a `ConstraintViolation` when the input does not satisfy + * the constraints. + */ +class ConstrainedIntegerGenerator( + val codegenContext: ServerCodegenContext, + val writer: RustWriter, + val shape: IntegerShape, +) { + val model = codegenContext.model + val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider + val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + + fun render() { + val rangeTrait = shape.expectTrait() + + val symbol = constrainedShapeSymbolProvider.toSymbol(shape) + val constrainedTypeName = symbol.name + val unconstrainedTypeName = RustType.Integer(32).render() + val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape) + val constraintsInfo = listOf(Range(rangeTrait).toTraitInfo(unconstrainedTypeName)) + + val constrainedTypeVisibility = if (publicConstrainedTypes) { + Visibility.PUBLIC + } else { + Visibility.PUBCRATE + } + val constrainedTypeMetadata = RustMetadata( + Attribute.Derives( + setOf( + RuntimeType.Debug, + RuntimeType.Clone, + RuntimeType.PartialEq, + RuntimeType.Eq, + RuntimeType.Hash, + ), + ), + visibility = constrainedTypeVisibility, + ) + + writer.documentShape(shape, model, note = rustDocsNote(constrainedTypeName)) + constrainedTypeMetadata.render(writer) + writer.rust("struct $constrainedTypeName(pub(crate) $unconstrainedTypeName);") + + if (constrainedTypeVisibility == Visibility.PUBCRATE) { + Attribute.AllowUnused.render(writer) + } + writer.rustTemplate( + """ + impl $constrainedTypeName { + /// ${rustDocsInnerMethod(unconstrainedTypeName)} + pub fn inner(&self) -> &$unconstrainedTypeName { + &self.0 + } + + /// ${rustDocsIntoInnerMethod(unconstrainedTypeName)} + pub fn into_inner(self) -> $unconstrainedTypeName { + self.0 + } + } + + impl #{ConstrainedTrait} for $constrainedTypeName { + type Unconstrained = $unconstrainedTypeName; + } + + impl #{From}<$unconstrainedTypeName> for #{MaybeConstrained} { + fn from(value: $unconstrainedTypeName) -> Self { + Self::Unconstrained(value) + } + } + + impl #{Display} for $constrainedTypeName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + ${shape.redactIfNecessary(model, "self.0")}.fmt(f) + } + } + + impl #{From}<$constrainedTypeName> for $unconstrainedTypeName { + fn from(value: $constrainedTypeName) -> Self { + value.into_inner() + } + } + """, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait(), + "ConstraintViolation" to constraintViolation, + "MaybeConstrained" to symbol.makeMaybeConstrained(), + "Display" to RuntimeType.Display, + "From" to RuntimeType.From, + "TryFrom" to RuntimeType.TryFrom, + "AsRef" to RuntimeType.AsRef, + ) + + writer.renderTryFrom(unconstrainedTypeName, constrainedTypeName, constraintViolation, constraintsInfo) + + writer.withInlineModule(constraintViolation.module()) { + rust( + """ + ##[derive(Debug, PartialEq)] + pub enum ${constraintViolation.name} { + Range($unconstrainedTypeName), + } + """, + ) + + if (shape.isReachableFromOperationInput()) { + rustBlock("impl ${constraintViolation.name}") { + rustBlockTemplate( + "pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField", + "String" to RuntimeType.String, + ) { + rustBlock("match self") { + rust( + """ + Self::Range(value) => crate::model::ValidationExceptionField { + message: format!("${rangeTrait.validationErrorMessage()}", value, &path), + path, + }, + """, + ) + } + } + } + } + } + } +} + +private data class Range(val rangeTrait: RangeTrait) { + fun toTraitInfo(unconstrainedTypeName: String): TraitInfo = TraitInfo( + { rust("Self::check_range(value)?;") }, + { + docs("Error when an integer doesn't satisfy its `@range` requirements.") + rust("Range($unconstrainedTypeName)") + }, + { + rust( + """ + Self::Range(value) => crate::model::ValidationExceptionField { + message: format!("${rangeTrait.validationErrorMessage()}", value, &path), + path, + }, + """, + ) + }, + this::renderValidationFunction, + ) + + /** + * Renders a `check_range` function to validate the integer matches the + * required range indicated by the `@range` trait. + */ + private fun renderValidationFunction(constraintViolation: Symbol, unconstrainedTypeName: String): Writable = { + val valueVariableName = "value" + val condition = if (rangeTrait.min.isPresent && rangeTrait.max.isPresent) { + "(${rangeTrait.min.get()}..=${rangeTrait.max.get()}).contains(&$valueVariableName)" + } else if (rangeTrait.min.isPresent) { + "${rangeTrait.min.get()} <= $valueVariableName" + } else { + "$valueVariableName <= ${rangeTrait.max.get()}" + } + + rust( + """ + fn check_range($valueVariableName: $unconstrainedTypeName) -> Result<(), $constraintViolation> { + if $condition { + Ok(()) + } else { + Err($constraintViolation::Range($valueVariableName)) + } + } + """, + ) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt index a63801cce44..38d1a2acc72 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt @@ -5,26 +5,33 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.model.traits.PatternTrait +import software.amazon.smithy.model.traits.Trait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.docs import software.amazon.smithy.rust.codegen.core.rustlang.documentShape +import software.amazon.smithy.rust.codegen.core.rustlang.join import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained -import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.core.smithy.module +import software.amazon.smithy.rust.codegen.core.util.PANIC +import software.amazon.smithy.rust.codegen.core.util.orNull import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.supportedStringConstraintTraits import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage @@ -49,23 +56,18 @@ class ConstrainedStringGenerator( PubCrateConstraintViolationSymbolProvider(this) } } + private val constraintsInfo: List = + supportedStringConstraintTraits + .mapNotNull { shape.getTrait(it).orNull() } + .map(StringTraitInfo::fromTrait) + .map(StringTraitInfo::toTraitInfo) fun render() { - val lengthTrait = shape.expectTrait() - val symbol = constrainedShapeSymbolProvider.toSymbol(shape) val name = symbol.name val inner = RustType.String.render() val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape) - val condition = if (lengthTrait.min.isPresent && lengthTrait.max.isPresent) { - "(${lengthTrait.min.get()}..=${lengthTrait.max.get()}).contains(&length)" - } else if (lengthTrait.min.isPresent) { - "${lengthTrait.min.get()} <= length" - } else { - "length <= ${lengthTrait.max.get()}" - } - val constrainedTypeVisibility = if (publicConstrainedTypes) { Visibility.PUBLIC } else { @@ -85,55 +87,47 @@ class ConstrainedStringGenerator( if (constrainedTypeVisibility == Visibility.PUBCRATE) { Attribute.AllowUnused.render(writer) } - writer.rustTemplate( + writer.rust( """ impl $name { /// Extracts a string slice containing the entire underlying `String`. pub fn as_str(&self) -> &str { &self.0 } - + /// ${rustDocsInnerMethod(inner)} pub fn inner(&self) -> &$inner { &self.0 } - + /// ${rustDocsIntoInnerMethod(inner)} pub fn into_inner(self) -> $inner { self.0 } - } - + }""", + ) + + writer.renderTryFrom(inner, name, constraintViolation, constraintsInfo) + + writer.rustTemplate( + """ impl #{ConstrainedTrait} for $name { type Unconstrained = $inner; } - + impl #{From}<$inner> for #{MaybeConstrained} { fn from(value: $inner) -> Self { Self::Unconstrained(value) } } - + impl #{Display} for $name { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { ${shape.redactIfNecessary(model, "self.0")}.fmt(f) } } - - impl #{TryFrom}<$inner> for $name { - type Error = #{ConstraintViolation}; - - /// ${rustDocsTryFromMethod(name, inner)} - fn try_from(value: $inner) -> Result { - let length = value.chars().count(); - if $condition { - Ok(Self(value)) - } else { - Err(#{ConstraintViolation}::Length(length)) - } - } - } - + + impl #{From}<$name> for $inner { fn from(value: $name) -> Self { value.into_inner() @@ -145,39 +139,163 @@ class ConstrainedStringGenerator( "MaybeConstrained" to symbol.makeMaybeConstrained(), "Display" to RuntimeType.Display, "From" to RuntimeType.From, - "TryFrom" to RuntimeType.TryFrom, ) - val constraintViolationModuleName = constraintViolation.namespace.split(constraintViolation.namespaceDelimiter).last() - writer.withModule(RustModule(constraintViolationModuleName, RustMetadata(visibility = constrainedTypeVisibility))) { - rust( + writer.withInlineModule(constraintViolation.module()) { + renderConstraintViolationEnum(this, shape, constraintViolation) + } + } + + private fun renderConstraintViolationEnum(writer: RustWriter, shape: StringShape, constraintViolation: Symbol) { + writer.rustTemplate( + """ + ##[derive(Debug, PartialEq)] + pub enum ${constraintViolation.name} { + #{Variants:W} + } + """, + "Variants" to constraintsInfo.map { it.constraintViolationVariant }.join(",\n"), + ) + + if (shape.isReachableFromOperationInput()) { + writer.rustTemplate( """ - ##[derive(Debug, PartialEq)] - pub enum ${constraintViolation.name} { - Length(usize), + impl ${constraintViolation.name} { + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + match self { + #{ValidationExceptionFields:W} + } + } } """, + "String" to RuntimeType.String, + "ValidationExceptionFields" to constraintsInfo.map { it.asValidationExceptionField }.join("\n"), + ) + } + } +} +private data class Length(val lengthTrait: LengthTrait) : StringTraitInfo() { + override fun toTraitInfo(): TraitInfo = TraitInfo( + { rust("Self::check_length(&value)?;") }, + { + docs("Error when a string doesn't satisfy its `@length` requirements.") + rust("Length(usize)") + }, + { + rust( + """ + Self::Length(length) => crate::model::ValidationExceptionField { + message: format!("${lengthTrait.validationErrorMessage()}", length, &path), + path, + }, + """, ) + }, + this::renderValidationFunction, + ) - if (shape.isReachableFromOperationInput()) { - rustBlock("impl ${constraintViolation.name}") { - rustBlockTemplate( - "pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField", - "String" to RuntimeType.String, - ) { - rustBlock("match self") { - rust( - """ - Self::Length(length) => crate::model::ValidationExceptionField { - message: format!("${lengthTrait.validationErrorMessage()}", length, &path), - path, - }, - """, - ) - } - } + /** + * Renders a `check_length` function to validate the string matches the + * required length indicated by the `@length` trait. + */ + private fun renderValidationFunction(constraintViolation: Symbol, unconstrainedTypeName: String): Writable = { + val condition = if (lengthTrait.min.isPresent && lengthTrait.max.isPresent) { + "(${lengthTrait.min.get()}..=${lengthTrait.max.get()}).contains(&length)" + } else if (lengthTrait.min.isPresent) { + "${lengthTrait.min.get()} <= length" + } else { + "length <= ${lengthTrait.max.get()}" + } + + rust( + """ + fn check_length(string: &str) -> Result<(), $constraintViolation> { + let length = string.chars().count(); + + if $condition { + Ok(()) + } else { + Err($constraintViolation::Length(length)) } } + """, + ) + } +} + +private data class Pattern(val patternTrait: PatternTrait) : StringTraitInfo() { + override fun toTraitInfo(): TraitInfo { + val pattern = patternTrait.pattern + + return TraitInfo( + { rust("let value = Self::check_pattern(value)?;") }, + { + docs("Error when a string doesn't satisfy its `@pattern`.") + docs("Contains the String that failed the pattern.") + rust("Pattern(String)") + }, + { + rust( + """ + Self::Pattern(string) => crate::model::ValidationExceptionField { + message: format!("${patternTrait.validationErrorMessage()}", &string, &path, r##"$pattern"##), + path + }, + """, + ) + }, + this::renderValidationFunction, + ) + } + + /** + * Renders a `check_pattern` function to validate the string matches the + * supplied regex in the `@pattern` trait. + */ + private fun renderValidationFunction(constraintViolation: Symbol, unconstrainedTypeName: String): Writable { + val pattern = patternTrait.pattern + val errorMessageForUnsupportedRegex = + """The regular expression $pattern is not supported by the `regex` crate; feel free to file an issue under https://github.com/awslabs/smithy-rs/issues for support""" + + return { + rustTemplate( + """ + fn check_pattern(string: $unconstrainedTypeName) -> Result<$unconstrainedTypeName, $constraintViolation> { + let regex = Self::compile_regex(); + + if regex.is_match(&string) { + Ok(string) + } else { + Err($constraintViolation::Pattern(string)) + } + } + + pub fn compile_regex() -> &'static #{Regex}::Regex { + static REGEX: #{OnceCell}::sync::Lazy<#{Regex}::Regex> = #{OnceCell}::sync::Lazy::new(|| #{Regex}::Regex::new(r##"$pattern"##).expect(r##"$errorMessageForUnsupportedRegex"##)); + + ®EX + } + """, + "Regex" to ServerCargoDependency.Regex.toType(), + "OnceCell" to ServerCargoDependency.OnceCell.toType(), + ) } } } + +private sealed class StringTraitInfo { + companion object { + fun fromTrait(trait: Trait): StringTraitInfo = + when (trait) { + is PatternTrait -> { + Pattern(trait) + } + is LengthTrait -> { + Length(trait) + } + else -> PANIC("StringTraitInfo.fromTrait called with unsupported trait $trait") + } + } + + abstract fun toTraitInfo(): TraitInfo +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/DocHandlerGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/DocHandlerGenerator.kt new file mode 100644 index 00000000000..d9a052b2802 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/DocHandlerGenerator.kt @@ -0,0 +1,64 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule +import software.amazon.smithy.rust.codegen.core.smithy.InputsModule +import software.amazon.smithy.rust.codegen.core.smithy.OutputsModule +import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.outputShape +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase + +/** +Generates a stub for use within documentation. + */ +class DocHandlerGenerator(private val operation: OperationShape, private val commentToken: String = "//", codegenContext: CodegenContext) { + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val crateName = codegenContext.settings.moduleName.toSnakeCase() + + /** + * Returns the function signature for an operation handler implementation. Used in the documentation. + */ + private fun OperationShape.docSignature(): Writable { + val inputSymbol = symbolProvider.toSymbol(inputShape(model)) + val outputSymbol = symbolProvider.toSymbol(outputShape(model)) + val errorSymbol = errorSymbol(model, symbolProvider, CodegenTarget.SERVER) + + val outputT = if (errors.isEmpty()) { + outputSymbol.name + } else { + "Result<${outputSymbol.name}, ${errorSymbol.name}>" + } + + return writable { + if (!errors.isEmpty()) { + rust("$commentToken ## use $crateName::${ErrorsModule.name}::${errorSymbol.name};") + } + rust( + """ + $commentToken ## use $crateName::${InputsModule.name}::${inputSymbol.name}; + $commentToken ## use $crateName::${OutputsModule.name}::${outputSymbol.name}; + $commentToken async fn handler(input: ${inputSymbol.name}) -> $outputT { + $commentToken todo!() + $commentToken } + """.trimIndent(), + ) + } + } + + fun render(writer: RustWriter) { + operation.docSignature()(writer) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt index 684d83322fd..cfcf5a53e1e 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt @@ -8,8 +8,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.traits.LengthTrait -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.rust @@ -17,6 +15,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider @@ -63,12 +62,7 @@ class MapConstraintViolationGenerator( } else { Visibility.PUBCRATE } - modelsModuleWriter.withModule( - RustModule( - constraintViolationSymbol.namespace.split(constraintViolationSymbol.namespaceDelimiter).last(), - RustMetadata(visibility = constraintViolationVisibility), - ), - ) { + modelsModuleWriter.withInlineModule(constraintViolationSymbol.module()) { // TODO(https://github.com/awslabs/smithy-rs/issues/1401) We should really have two `ConstraintViolation` // types here. One will just have variants for each constraint trait on the map shape, for use by the user. // The other one will have variants if the shape's key or value is directly or transitively constrained, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt index b789c2166df..66380dd6383 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt @@ -7,12 +7,10 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.MapShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained @@ -54,7 +52,6 @@ class PubCrateConstrainedCollectionGenerator( val constrainedSymbol = pubCrateConstrainedShapeSymbolProvider.toSymbol(shape) val unconstrainedSymbol = unconstrainedShapeSymbolProvider.toSymbol(shape) - val moduleName = constrainedSymbol.namespace.split(constrainedSymbol.namespaceDelimiter).last() val name = constrainedSymbol.name val innerShape = model.expectShape(shape.member.target) val innerConstrainedSymbol = if (innerShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) { @@ -71,7 +68,7 @@ class PubCrateConstrainedCollectionGenerator( "From" to RuntimeType.From, ) - writer.withModule(RustModule(moduleName, RustMetadata(visibility = Visibility.PUBCRATE))) { + writer.withInlineModule(constrainedSymbol.module()) { rustTemplate( """ ##[derive(Debug, Clone)] @@ -105,38 +102,45 @@ class PubCrateConstrainedCollectionGenerator( """ impl #{From}<#{Symbol}> for $name { fn from(v: #{Symbol}) -> Self { - ${ if (innerNeedsConstraining) { + ${ + if (innerNeedsConstraining) { "Self(v.into_iter().map(|item| item.into()).collect())" } else { "Self(v)" - } } + } + } } } impl #{From}<$name> for #{Symbol} { fn from(v: $name) -> Self { - ${ if (innerNeedsConstraining) { + ${ + if (innerNeedsConstraining) { "v.0.into_iter().map(|item| item.into()).collect()" } else { "v.0" - } } + } + } } } """, *codegenScope, ) } else { - val innerNeedsConversion = innerShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes) + val innerNeedsConversion = + innerShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes) rustTemplate( """ impl #{From}<$name> for #{Symbol} { fn from(v: $name) -> Self { - ${ if (innerNeedsConversion) { + ${ + if (innerNeedsConversion) { "v.0.into_iter().map(|item| item.into()).collect()" } else { "v.0" - } } + } + } } } """, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt index 591b11b7ed0..d11bcad6b8d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt @@ -8,12 +8,10 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.StringShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained @@ -52,7 +50,6 @@ class PubCrateConstrainedMapGenerator( val symbol = symbolProvider.toSymbol(shape) val unconstrainedSymbol = unconstrainedShapeSymbolProvider.toSymbol(shape) val constrainedSymbol = pubCrateConstrainedShapeSymbolProvider.toSymbol(shape) - val moduleName = constrainedSymbol.namespace.split(constrainedSymbol.namespaceDelimiter).last() val name = constrainedSymbol.name val keyShape = model.expectShape(shape.key.target, StringShape::class.java) val valueShape = model.expectShape(shape.value.target) @@ -72,7 +69,7 @@ class PubCrateConstrainedMapGenerator( "From" to RuntimeType.From, ) - writer.withModule(RustModule(moduleName, RustMetadata(visibility = Visibility.PUBCRATE))) { + writer.withInlineModule(constrainedSymbol.module()) { rustTemplate( """ ##[derive(Debug, Clone)] diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt index a2c04d1a6eb..6828042be53 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt @@ -35,8 +35,6 @@ import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.DefaultTrait import software.amazon.smithy.model.traits.EnumDefinition import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Visibility @@ -63,6 +61,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained import software.amazon.smithy.rust.codegen.core.smithy.makeOptional import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed import software.amazon.smithy.rust.codegen.core.smithy.mapRustType +import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.util.dq @@ -148,7 +147,6 @@ class ServerBuilderGenerator( private val members: List = shape.allMembers.values.toList() private val structureSymbol = symbolProvider.toSymbol(shape) private val builderSymbol = shape.serverBuilderSymbol(codegenContext) - private val moduleName = builderSymbol.namespace.split(builderSymbol.namespaceDelimiter).last() private val isBuilderFallible = hasFallibleBuilder(shape, model, symbolProvider, takeInUnconstrainedTypes) private val serverBuilderConstraintViolations = ServerBuilderConstraintViolations(codegenContext, shape, takeInUnconstrainedTypes) @@ -163,7 +161,7 @@ class ServerBuilderGenerator( fun render(writer: RustWriter) { writer.docs("See #D.", structureSymbol) - writer.withModule(RustModule(moduleName, RustMetadata(visibility = visibility))) { + writer.withInlineModule(builderSymbol.module()) { renderBuilder(this) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt index fc6932b835f..e342d5c8cfd 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt @@ -9,7 +9,6 @@ import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock @@ -25,6 +24,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.makeOptional +import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType @@ -83,7 +83,7 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes( fun render(writer: RustWriter) { writer.docs("See #D.", structureSymbol) - writer.withModule(RustModule.public(moduleName)) { + writer.withInlineModule(builderSymbol.module()) { renderBuilder(this) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt index a8ee7fd8f67..9720717383c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt @@ -8,14 +8,20 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext fun StructureShape.serverBuilderSymbol(codegenContext: ServerCodegenContext): Symbol = - this.serverBuilderSymbol(codegenContext.symbolProvider, !codegenContext.settings.codegenConfig.publicConstrainedTypes) + this.serverBuilderSymbol( + codegenContext.symbolProvider, + !codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) fun StructureShape.serverBuilderSymbol(symbolProvider: SymbolProvider, pubCrate: Boolean): Symbol { val structureSymbol = symbolProvider.toSymbol(this) @@ -25,11 +31,15 @@ fun StructureShape.serverBuilderSymbol(symbolProvider: SymbolProvider, pubCrate: } else { "" } - val rustType = RustType.Opaque("Builder", "${structureSymbol.namespace}::$builderNamespace") + val visibility = when (pubCrate) { + true -> Visibility.PUBCRATE + false -> Visibility.PUBLIC + } + val builderModule = RustModule.new(builderNamespace, visibility, parent = structureSymbol.module(), inline = true) + val rustType = RustType.Opaque("Builder", builderModule.fullyQualifiedPath()) return Symbol.builder() .rustType(rustType) .name(rustType.name) - .namespace(rustType.namespace, "::") - .definitionFile(structureSymbol.definitionFile) + .module(builderModule) .build() } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt index 1514750cda4..893f4d14f48 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider @@ -42,8 +43,8 @@ open class ServerEnumGenerator( ) override fun renderFromForStr() { - writer.withModule( - RustModule.public(constraintViolationSymbol.namespace.split(constraintViolationSymbol.namespaceDelimiter).last()), + writer.withInlineModule( + constraintViolationSymbol.module() as RustModule.LeafModule, ) { rustTemplate( """ diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGenerator.kt index e6e1c1ac400..18dfbd869d3 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGenerator.kt @@ -22,9 +22,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.Errors -import software.amazon.smithy.rust.codegen.core.smithy.Inputs -import software.amazon.smithy.rust.codegen.core.smithy.Outputs +import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule +import software.amazon.smithy.rust.codegen.core.smithy.InputsModule +import software.amazon.smithy.rust.codegen.core.smithy.OutputsModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.core.util.getTrait @@ -86,9 +86,9 @@ class ServerOperationRegistryGenerator( private fun renderOperationRegistryRustDocs(writer: RustWriter) { val inputOutputErrorsImport = if (operations.any { it.errors.isNotEmpty() }) { - "/// use ${crateName.toSnakeCase()}::{${Inputs.namespace}, ${Outputs.namespace}, ${Errors.namespace}};" + "/// use ${crateName.toSnakeCase()}::{${InputsModule.name}, ${OutputsModule.name}, ${ErrorsModule.name}};" } else { - "/// use ${crateName.toSnakeCase()}::{${Inputs.namespace}, ${Outputs.namespace}};" + "/// use ${crateName.toSnakeCase()}::{${InputsModule.name}, ${OutputsModule.name}};" } writer.rustTemplate( @@ -379,12 +379,13 @@ ${operationImplementationStubs(operations)} val outputSymbol = symbolProvider.toSymbol(outputShape(model)) val errorSymbol = errorSymbol(model, symbolProvider, CodegenTarget.SERVER) - val inputT = "${Inputs.namespace}::${inputSymbol.name}" - val t = "${Outputs.namespace}::${outputSymbol.name}" + // using module names here to avoid generating `crate::...` since we've already added the import + val inputT = "${InputsModule.name}::${inputSymbol.name}" + val t = "${OutputsModule.name}::${outputSymbol.name}" val outputT = if (errors.isEmpty()) { t } else { - val e = "${Errors.namespace}::${errorSymbol.name}" + val e = "${ErrorsModule.name}::${errorSymbol.name}" "Result<$t, $e>" } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationShapeGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationShapeGenerator.kt new file mode 100644 index 00000000000..8ae62beca1e --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationShapeGenerator.kt @@ -0,0 +1,71 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.util.toPascalCase +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency + +class ServerOperationShapeGenerator( + private val operations: List, + private val codegenContext: CodegenContext, +) { + + fun render(writer: RustWriter) { + if (operations.isEmpty()) { + return + } + + val firstOperation = codegenContext.symbolProvider.toSymbol(operations[0]) + val firstOperationName = firstOperation.name.toPascalCase() + val crateName = codegenContext.settings.moduleName.toSnakeCase() + + writer.rustTemplate( + """ + //! A collection of zero-sized types (ZSTs) representing each operation defined in the service closure. + //! + //! ## Constructing an [`Operation`](#{SmithyHttpServer}::operation::OperationShapeExt) + //! + //! To apply middleware to specific operations the [`Operation`](#{SmithyHttpServer}::operation::Operation) + //! API must be used. + //! + //! Using the [`OperationShapeExt`](#{SmithyHttpServer}::operation::OperationShapeExt) trait + //! implemented on each ZST we can construct an [`Operation`](#{SmithyHttpServer}::operation::Operation) + //! with appropriate constraints given by Smithy. + //! + //! #### Example + //! + //! ```no_run + //! use $crateName::operation_shape::$firstOperationName; + //! use #{SmithyHttpServer}::operation::OperationShapeExt; + #{Handler:W} + //! + //! let operation = $firstOperationName::from_handler(handler) + //! .layer(todo!("Provide a layer implementation")); + //! ``` + //! + //! ## Use as Marker Structs + //! + //! The [plugin system](#{SmithyHttpServer}::plugin) also makes use of these ZSTs to parameterize + //! [`Plugin`](#{SmithyHttpServer}::plugin::Plugin) implementations. The traits, such as + //! [`OperationShape`](#{SmithyHttpServer}::operation::OperationShape) can be used to provide + //! operation specific information to the [`Layer`](#{Tower}::Layer) being applied. + """.trimIndent(), + "SmithyHttpServer" to + ServerCargoDependency.SmithyHttpServer(codegenContext.runtimeConfig).toType(), + "Tower" to ServerCargoDependency.Tower.toType(), + "Handler" to DocHandlerGenerator(operations[0], "//!", codegenContext)::render, + ) + for (operation in operations) { + ServerOperationGenerator(codegenContext, operation).render(writer) + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt index ca6bd0fe5ae..e4082484836 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt @@ -35,6 +35,7 @@ open class ServerServiceGenerator( ) { private val index = TopDownIndex.of(codegenContext.model) protected val operations = index.getContainedOperations(codegenContext.serviceShape).sortedBy { it.id } + private val serviceName = codegenContext.serviceShape.id.name.toString() /** * Render Service Specific code. Code will end up in different files via [useShapeWriter]. See `SymbolVisitor.kt` @@ -42,7 +43,6 @@ open class ServerServiceGenerator( */ fun render() { rustCrate.lib { - val serviceName = codegenContext.serviceShape.id.name.toString() rust("##[doc(inline, hidden)]") rust("pub use crate::service::$serviceName;") } @@ -75,7 +75,7 @@ open class ServerServiceGenerator( // TODO(https://github.com/awslabs/smithy-rs/issues/1707): Remove, this is temporary. rustCrate.withModule( - RustModule( + RustModule.LeafModule( "operation_shape", RustMetadata( visibility = Visibility.PUBLIC, @@ -85,14 +85,12 @@ open class ServerServiceGenerator( ), ), ) { - for (operation in operations) { - ServerOperationGenerator(codegenContext, operation).render(this) - } + ServerOperationShapeGenerator(operations, codegenContext).render(this) } // TODO(https://github.com/awslabs/smithy-rs/issues/1707): Remove, this is temporary. rustCrate.withModule( - RustModule("service", RustMetadata(visibility = Visibility.PUBLIC, additionalAttributes = listOf(Attribute.DocHidden)), null), + RustModule.LeafModule("service", RustMetadata(visibility = Visibility.PUBLIC, additionalAttributes = listOf(Attribute.DocHidden)), null), ) { ServerServiceGeneratorV2( codegenContext, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt index 7b553ee119e..ca782f4da6b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt @@ -23,7 +23,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol class ServerServiceGeneratorV2( - codegenContext: CodegenContext, + private val codegenContext: CodegenContext, private val protocol: ServerProtocol, ) { private val runtimeConfig = codegenContext.runtimeConfig @@ -32,12 +32,14 @@ class ServerServiceGeneratorV2( arrayOf( "Bytes" to CargoDependency.Bytes.toType(), "Http" to CargoDependency.Http.toType(), + "SmithyHttp" to CargoDependency.smithyHttp(runtimeConfig).toType(), "HttpBody" to CargoDependency.HttpBody.toType(), "SmithyHttpServer" to smithyHttpServer, "Tower" to CargoDependency.Tower.toType(), ) private val model = codegenContext.model private val symbolProvider = codegenContext.symbolProvider + val crateName = codegenContext.settings.moduleName.toSnakeCase() private val service = codegenContext.serviceShape private val serviceName = service.id.name.toPascalCase() @@ -101,6 +103,22 @@ class ServerServiceGeneratorV2( /// /// This should be an async function satisfying the [`Handler`](#{SmithyHttpServer}::operation::Handler) trait. /// See the [operation module documentation](#{SmithyHttpServer}::operation) for more information. + /// + /// ## Example + /// + /// ```no_run + /// use $crateName::$serviceName; + /// + #{Handler:W} + /// + /// let app = $serviceName::builder_without_plugins() + /// .$fieldName(handler) + /// /* Set other handlers */ + /// .build() + /// .unwrap(); + /// ## let app: $serviceName<#{SmithyHttpServer}::routing::Route<#{SmithyHttp}::body::SdkBody>> = app; + /// ``` + /// pub fn $fieldName(self, handler: HandlerType) -> Self where HandlerType: #{SmithyHttpServer}::operation::Handler, @@ -138,6 +156,7 @@ class ServerServiceGeneratorV2( } """, "Protocol" to protocol.markerStruct(), + "Handler" to DocHandlerGenerator(operationShape, "///", codegenContext)::render, *codegenScope, ) @@ -179,6 +198,9 @@ class ServerServiceGeneratorV2( /// Constructs a [`$serviceName`] from the arguments provided to the builder. /// /// Forgetting to register a handler for one or more operations will result in an error. + /// + /// Check out [`$builderName::build_unchecked`] if you'd prefer the service to return status code 500 when an + /// unspecified route requested. pub fn build(self) -> Result<$serviceName<#{SmithyHttpServer}::routing::Route<$builderBodyGenericTypeName>>, MissingOperationsError> { let router = { @@ -343,7 +365,7 @@ class ServerServiceGeneratorV2( /// Constructs a builder for [`$serviceName`]. /// - /// Use [`$serviceName::builder_without_plugins`] if you need to specify plugins. + /// Use [`$serviceName::builder_with_plugins`] if you need to specify plugins. pub fn builder_without_plugins() -> $builderName { Self::builder_with_plugins(#{SmithyHttpServer}::plugin::IdentityPlugin) } @@ -355,9 +377,9 @@ class ServerServiceGeneratorV2( #{SmithyHttpServer}::routing::IntoMakeService::new(self) } - /// Converts [`$serviceName`] into a [`MakeService`](tower::make::MakeService) with [`ConnectInfo`](#{SmithyHttpServer}::routing::into_make_service_with_connect_info::ConnectInfo). - pub fn into_make_service_with_connect_info(self) -> #{SmithyHttpServer}::routing::IntoMakeServiceWithConnectInfo { - #{SmithyHttpServer}::routing::IntoMakeServiceWithConnectInfo::new(self) + /// Converts [`$serviceName`] into a [`MakeService`](tower::make::MakeService) with [`ConnectInfo`](#{SmithyHttpServer}::request::connect_info::ConnectInfo). + pub fn into_make_service_with_connect_info(self) -> #{SmithyHttpServer}::request::connect_info::IntoMakeServiceWithConnectInfo { + #{SmithyHttpServer}::request::connect_info::IntoMakeServiceWithConnectInfo::new(self) } /// Applies a [`Layer`](#{Tower}::Layer) uniformly to all routes. @@ -415,6 +437,8 @@ class ServerServiceGeneratorV2( private fun missingOperationsError(): Writable = writable { rust( """ + /// The error encountered when calling the [`$builderName::build`] method if one or more operation handlers are not + /// specified. ##[derive(Debug)] pub struct MissingOperationsError { operation_names2setter_methods: std::collections::HashMap<&'static str, &'static str>, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/TraitInfo.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/TraitInfo.kt new file mode 100644 index 00000000000..afd8b55aaec --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/TraitInfo.kt @@ -0,0 +1,66 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.join +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType + +/** + * Information needed to render a constraint trait as Rust code. + */ +data class TraitInfo( + val tryFromCheck: Writable, + val constraintViolationVariant: Writable, + val asValidationExceptionField: Writable, + val validationFunctionDefinition: (constraintViolation: Symbol, unconstrainedTypeName: String) -> Writable, +) + +/** + * Render the implementation of `TryFrom` for a constrained type. + */ +fun RustWriter.renderTryFrom( + unconstrainedTypeName: String, + constrainedTypeName: String, + constraintViolationError: Symbol, + constraintsInfo: List, +) { + this.rustTemplate( + """ + impl $constrainedTypeName { + #{ValidationFunctions:W} + } + """, + "ValidationFunctions" to constraintsInfo.map { + it.validationFunctionDefinition( + constraintViolationError, + unconstrainedTypeName, + ) + } + .join("\n"), + ) + + this.rustTemplate( + """ + impl #{TryFrom}<$unconstrainedTypeName> for $constrainedTypeName { + type Error = #{ConstraintViolation}; + + /// ${rustDocsTryFromMethod(constrainedTypeName, unconstrainedTypeName)} + fn try_from(value: $unconstrainedTypeName) -> Result { + #{TryFromChecks:W} + + Ok(Self(value)) + } + } + """, + "TryFrom" to RuntimeType.TryFrom, + "ConstraintViolation" to constraintViolationError, + "TryFromChecks" to constraintsInfo.map { it.tryFromCheck }.join("\n"), + ) +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt index 602cbb7aafc..33a33ecff04 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt @@ -6,13 +6,11 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.shapes.CollectionShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained +import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape @@ -53,7 +51,6 @@ class UnconstrainedCollectionGenerator( check(shape.canReachConstrainedShape(model, symbolProvider)) val symbol = unconstrainedShapeSymbolProvider.toSymbol(shape) - val module = symbol.namespace.split(symbol.namespaceDelimiter).last() val name = symbol.name val innerShape = model.expectShape(shape.member.target) val innerUnconstrainedSymbol = unconstrainedShapeSymbolProvider.toSymbol(innerShape) @@ -62,21 +59,21 @@ class UnconstrainedCollectionGenerator( val constraintViolationName = constraintViolationSymbol.name val innerConstraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(innerShape) - unconstrainedModuleWriter.withModule(RustModule(module, RustMetadata(visibility = Visibility.PUBCRATE))) { + unconstrainedModuleWriter.withInlineModule(symbol.module()) { rustTemplate( """ ##[derive(Debug, Clone)] pub(crate) struct $name(pub(crate) std::vec::Vec<#{InnerUnconstrainedSymbol}>); - + impl From<$name> for #{MaybeConstrained} { fn from(value: $name) -> Self { Self::Unconstrained(value) } } - + impl #{TryFrom}<$name> for #{ConstrainedSymbol} { type Error = #{ConstraintViolationSymbol}; - + fn try_from(value: $name) -> Result { let res: Result<_, (usize, #{InnerConstraintViolationSymbol})> = value .0 @@ -84,7 +81,7 @@ class UnconstrainedCollectionGenerator( .enumerate() .map(|(idx, inner)| inner.try_into().map_err(|inner_violation| (idx, inner_violation))) .collect(); - res.map(Self) + res.map(Self) .map_err(|(idx, inner_violation)| #{ConstraintViolationSymbol}(idx, inner_violation)) } } @@ -98,17 +95,7 @@ class UnconstrainedCollectionGenerator( ) } - val constraintViolationVisibility = if (publicConstrainedTypes) { - Visibility.PUBLIC - } else { - Visibility.PUBCRATE - } - modelsModuleWriter.withModule( - RustModule( - constraintViolationSymbol.namespace.split(constraintViolationSymbol.namespaceDelimiter).last(), - RustMetadata(visibility = constraintViolationVisibility), - ), - ) { + modelsModuleWriter.withInlineModule(constraintViolationSymbol.module()) { // The first component of the tuple struct is the index in the collection where the first constraint // violation was found. rustTemplate( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt index 4d47eb62290..3d1f2a3898d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt @@ -7,16 +7,14 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.StringShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.join import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained +import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape @@ -66,22 +64,21 @@ class UnconstrainedMapGenerator( fun render() { check(shape.canReachConstrainedShape(model, symbolProvider)) - val module = symbol.namespace.split(symbol.namespaceDelimiter).last() val keySymbol = unconstrainedShapeSymbolProvider.toSymbol(keyShape) val valueSymbol = unconstrainedShapeSymbolProvider.toSymbol(valueShape) - unconstrainedModuleWriter.withModule(RustModule(module, RustMetadata(visibility = Visibility.PUBCRATE))) { + unconstrainedModuleWriter.withInlineModule(symbol.module()) { rustTemplate( """ ##[derive(Debug, Clone)] pub(crate) struct $name(pub(crate) std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}>); - + impl From<$name> for #{MaybeConstrained} { fn from(value: $name) -> Self { Self::Unconstrained(value) } } - + """, "KeySymbol" to keySymbol, "ValueSymbol" to valueSymbol, @@ -185,7 +182,7 @@ class UnconstrainedMapGenerator( // ``` rustTemplate( """ - let hm: std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}> = + let hm: std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}> = hm.into_iter().map(|(k, v)| (k, v.into())).collect(); """, "KeySymbol" to symbolProvider.toSymbol(keyShape), diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt index 7e3ed35e5ce..8ff79516f88 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt @@ -10,8 +10,6 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.rust @@ -24,6 +22,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed +import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.letIf @@ -71,13 +70,12 @@ class UnconstrainedUnionGenerator( fun render() { check(shape.canReachConstrainedShape(model, symbolProvider)) - val moduleName = symbol.namespace.split(symbol.namespaceDelimiter).last() val name = symbol.name val constrainedSymbol = pubCrateConstrainedShapeSymbolProvider.toSymbol(shape) val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) val constraintViolationName = constraintViolationSymbol.name - unconstrainedModuleWriter.withModule(RustModule(moduleName, RustMetadata(visibility = Visibility.PUBCRATE))) { + unconstrainedModuleWriter.withInlineModule(symbol.module()) { rustBlock( """ ##[allow(clippy::enum_variant_names)] @@ -133,14 +131,11 @@ class UnconstrainedUnionGenerator( } else { Visibility.PUBCRATE } - modelsModuleWriter.withModule( - RustModule( - constraintViolationSymbol.namespace.split(constraintViolationSymbol.namespaceDelimiter).last(), - RustMetadata(visibility = constraintViolationVisibility), - ), + modelsModuleWriter.withInlineModule( + constraintViolationSymbol.module(), ) { Attribute.Derives(setOf(RuntimeType.Debug, RuntimeType.PartialEq)).render(this) - rustBlock("pub${ if (constraintViolationVisibility == Visibility.PUBCRATE) " (crate)" else "" } enum $constraintViolationName") { + rustBlock("pub${if (constraintViolationVisibility == Visibility.PUBCRATE) " (crate)" else ""} enum $constraintViolationName") { constraintViolations().forEach { renderConstraintViolation(this, it) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt index f866d83e3a4..b01c2f633e7 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt @@ -75,7 +75,9 @@ class ServerRequestBindingGenerator( class ServerRequestAfterDeserializingIntoAHashMapOfHttpPrefixHeadersWrapInUnconstrainedMapHttpBindingCustomization(val codegenContext: ServerCodegenContext) : HttpBindingCustomization() { override fun section(section: HttpBindingSection): Writable = when (section) { - is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders -> emptySection + is HttpBindingSection.BeforeRenderingHeaderValue, + is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders, + -> emptySection is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders -> writable { if (section.memberShape.targetCanReachConstrainedShape(codegenContext.model, codegenContext.unconstrainedShapeSymbolProvider)) { rust( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt index 2a030d10704..20e9f973632 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt @@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindi import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingSection import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType @@ -40,6 +41,9 @@ class ServerResponseBindingGenerator( ServerResponseBeforeIteratingOverMapBoundWithHttpPrefixHeadersUnwrapConstrainedMapHttpBindingCustomization( codegenContext, ), + ServerResponseBeforeRenderingHeadersHttpBindingCustomization( + codegenContext, + ), ), ) @@ -65,6 +69,36 @@ class ServerResponseBeforeIteratingOverMapBoundWithHttpPrefixHeadersUnwrapConstr rust("let ${section.variableName} = &${section.variableName}.0;") } } - is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders -> emptySection + + is HttpBindingSection.BeforeRenderingHeaderValue, + is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders, + -> emptySection + } +} + +/** + * A customization to, just before we render a _constrained_ member shape to an HTTP response header, + * unwrap the wrapper newtype and take a shared reference to the actual inner type within it. + */ +class ServerResponseBeforeRenderingHeadersHttpBindingCustomization(val codegenContext: ServerCodegenContext) : + HttpBindingCustomization() { + override fun section(section: HttpBindingSection): Writable = when (section) { + is HttpBindingSection.BeforeRenderingHeaderValue -> writable { + if (workingWithPublicConstrainedWrapperTupleType( + section.context.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + if (section.context.shape.isIntegerShape) { + section.context.valueExpression = + ValueExpression.Reference("&${section.context.valueExpression.name.removePrefix("&")}.0") + } + } + } + + is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders, + is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders, + -> emptySection } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index cf981edb73c..60a26e54c26 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -18,6 +18,7 @@ import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.protocoltests.traits.AppliesTo import software.amazon.smithy.protocoltests.traits.HttpMalformedRequestTestCase import software.amazon.smithy.protocoltests.traits.HttpMalformedRequestTestsTrait +import software.amazon.smithy.protocoltests.traits.HttpMalformedResponseBodyDefinition import software.amazon.smithy.protocoltests.traits.HttpMalformedResponseDefinition import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait @@ -175,7 +176,7 @@ class ServerProtocolTestGenerator( } } - val module = RustModule( + val module = RustModule.LeafModule( PROTOCOL_TEST_HELPER_MODULE_NAME, RustMetadata( additionalAttributes = listOf( @@ -184,9 +185,10 @@ class ServerProtocolTestGenerator( ), visibility = Visibility.PUBCRATE, ), + inline = true, ) - writer.withModule(module) { + writer.withInlineModule(module) { rustTemplate( """ use #{Tower}::Service as _; @@ -252,7 +254,7 @@ class ServerProtocolTestGenerator( if (allTests.isNotEmpty()) { val operationName = operationSymbol.name - val module = RustModule( + val module = RustModule.LeafModule( "server_${operationName.toSnakeCase()}_test", RustMetadata( additionalAttributes = listOf( @@ -261,8 +263,9 @@ class ServerProtocolTestGenerator( ), visibility = Visibility.PRIVATE, ), + inline = true, ) - writer.withModule(module) { + writer.withInlineModule(module) { renderAllTestCases(operationShape, allTests) } } @@ -339,8 +342,13 @@ class ServerProtocolTestGenerator( } is TestCase.MalformedRequestTest -> { - // We haven't found any broken `HttpMalformedRequestTest`s yet. - it + val howToFixIt = BrokenMalformedRequestTests[Pair(codegenContext.serviceShape.id.toString(), it.id)] + if (howToFixIt == null) { + it + } else { + val fixed = howToFixIt(it.testCase) + TestCase.MalformedRequestTest(fixed) + } } } } @@ -901,6 +909,7 @@ class ServerProtocolTestGenerator( private const val AwsJson11 = "aws.protocoltests.json#JsonProtocol" private const val RestJson = "aws.protocoltests.restjson#RestJson" private const val RestJsonValidation = "aws.protocoltests.restjson.validation#RestJsonValidation" + private const val MalformedRangeValidation = "aws.protocoltests.extras.restjson.validation#MalformedRangeValidation" private val ExpectFail: Set = setOf( // Pending merge from the Smithy team: see https://github.com/awslabs/smithy/pull/1477. FailingTest(RestJson, "RestJsonWithPayloadExpectsImpliedContentType", TestType.MalformedRequest), @@ -969,17 +978,6 @@ class ServerProtocolTestGenerator( FailingTest(RestJsonValidation, "RestJsonMalformedPatternStringOverride_case1", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedPatternUnionOverride_case0", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedPatternUnionOverride_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternList_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternList_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternMapKey_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternMapKey_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternMapValue_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternMapValue_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternReDOSString", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternString_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternString_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternUnion_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternUnion_case1", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeByteOverride_case0", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeByteOverride_case1", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeFloatOverride_case0", TestType.MalformedRequest), @@ -1000,6 +998,31 @@ class ServerProtocolTestGenerator( FailingTest(RestJsonValidation, "RestJsonMalformedRangeMinFloat", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedPatternSensitiveString", TestType.MalformedRequest), + // Tests involving using @range on bytes, shorts and longs. + // See https://github.com/awslabs/smithy-rs/issues/1968 + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeShort_case0", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeShort_case1", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeLong_case0", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeLong_case1", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMaxShort", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMaxLong", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMinShort", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMinLong", TestType.MalformedRequest), + + // See https://github.com/awslabs/smithy-rs/issues/1969 + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeShortOverride_case0", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeShortOverride_case1", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeIntegerOverride_case0", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeIntegerOverride_case1", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeLongOverride_case0", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeLongOverride_case1", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMaxShortOverride", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMaxIntegerOverride", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMaxLongOverride", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMinShortOverride", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMinIntegerOverride", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMinLongOverride", TestType.MalformedRequest), + // Some tests for the S3 service (restXml). FailingTest("com.amazonaws.s3#AmazonS3", "GetBucketLocationUnwrappedOutput", TestType.Response), FailingTest("com.amazonaws.s3#AmazonS3", "S3DefaultAddressing", TestType.Request), @@ -1176,6 +1199,27 @@ class ServerProtocolTestGenerator( private fun fixRestJsonComplexErrorWithNoMessage(testCase: HttpResponseTestCase): HttpResponseTestCase = testCase.toBuilder().putHeader("X-Amzn-Errortype", "aws.protocoltests.restjson#ComplexError").build() + // TODO(https://github.com/awslabs/smithy/issues/1506) + private fun fixRestJsonMalformedPatternReDOSString(testCase: HttpMalformedRequestTestCase): HttpMalformedRequestTestCase { + val brokenResponse = testCase.response + val brokenBody = brokenResponse.body.get() + val fixedBody = HttpMalformedResponseBodyDefinition.builder() + .mediaType(brokenBody.mediaType) + .contents( + """ + { + "message" : "1 validation error detected. Value 000000000000000000000000000000000000000000000000000000000000000000000000000000000000! at '/evilString' failed to satisfy constraint: Member must satisfy regular expression pattern: ^([0-9]+)+${'$'}", + "fieldList" : [{"message": "Value 000000000000000000000000000000000000000000000000000000000000000000000000000000000000! at '/evilString' failed to satisfy constraint: Member must satisfy regular expression pattern: ^([0-9]+)+${'$'}", "path": "/evilString"}] + } + """.trimIndent(), + ) + .build() + + return testCase.toBuilder() + .response(brokenResponse.toBuilder().body(fixedBody).build()) + .build() + } + // These are tests whose definitions in the `awslabs/smithy` repository are wrong. // This is because they have not been written from a server perspective, and as such the expected `params` field is incomplete. // TODO(https://github.com/awslabs/smithy-rs/issues/1288): Contribute a PR to fix them upstream. @@ -1258,5 +1302,11 @@ class ServerProtocolTestGenerator( Pair(RestJson, "RestJsonEmptyComplexErrorWithNoMessage") to ::fixRestJsonEmptyComplexErrorWithNoMessage, Pair(RestJson, "RestJsonComplexErrorWithNoMessage") to ::fixRestJsonComplexErrorWithNoMessage, ) + + private val BrokenMalformedRequestTests: Map, KFunction1> = + // TODO(https://github.com/awslabs/smithy/issues/1506) + mapOf( + Pair(RestJsonValidation, "RestJsonMalformedPatternReDOSString") to ::fixRestJsonMalformedPatternReDOSString, + ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt index 815608fb249..ad50011edec 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt @@ -22,6 +22,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.Struc import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeIteratingOverMapJsonCustomization +import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeSerializingMemberJsonCustomization import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerAwsJsonProtocol import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol @@ -31,7 +32,8 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser */ class ServerAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGeneratorFactory { - override fun protocol(codegenContext: ServerCodegenContext): ServerProtocol = ServerAwsJsonProtocol(codegenContext, version) + override fun protocol(codegenContext: ServerCodegenContext): ServerProtocol = + ServerAwsJsonProtocol(codegenContext, version) override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator = ServerHttpBoundProtocolGenerator(codegenContext, protocol(codegenContext)) @@ -71,6 +73,7 @@ class ServerAwsJsonError(private val awsJsonVersion: AwsJsonVersion) : JsonSeria rust("""${section.jsonObject}.key("__type").string("${escape(typeId)}");""") } } + else -> emptySection } } @@ -90,6 +93,10 @@ class ServerAwsJsonSerializerGenerator( codegenContext, httpBindingResolver, ::awsJsonFieldName, - customizations = listOf(ServerAwsJsonError(awsJsonVersion), BeforeIteratingOverMapJsonCustomization(codegenContext)), + customizations = listOf( + ServerAwsJsonError(awsJsonVersion), + BeforeIteratingOverMapJsonCustomization(codegenContext), + BeforeSerializingMemberJsonCustomization(codegenContext), + ), ), ) : StructuredDataSerializerGenerator by jsonSerializerGenerator diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt index a913b806d21..a53bb363da7 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt @@ -14,6 +14,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonS import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeIteratingOverMapJsonCustomization +import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeSerializingMemberJsonCustomization import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRestJsonProtocol /** @@ -50,6 +51,9 @@ class ServerRestJsonSerializerGenerator( codegenContext, httpBindingResolver, ::restJsonFieldName, - customizations = listOf(BeforeIteratingOverMapJsonCustomization(codegenContext)), + customizations = listOf( + BeforeIteratingOverMapJsonCustomization(codegenContext), + BeforeSerializingMemberJsonCustomization(codegenContext), + ), ), ) : StructuredDataSerializerGenerator by jsonSerializerGenerator diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ShapeReachableFromOperationInputTagTrait.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ShapeReachableFromOperationInputTagTrait.kt index 14a2f720927..b4bd9a255b9 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ShapeReachableFromOperationInputTagTrait.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ShapeReachableFromOperationInputTagTrait.kt @@ -7,6 +7,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.traits import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.Shape @@ -30,7 +31,7 @@ class ShapeReachableFromOperationInputTagTrait : AnnotationTrait(ID, Node.object } private fun isShapeReachableFromOperationInput(shape: Shape) = when (shape) { - is StructureShape, is UnionShape, is MapShape, is ListShape, is StringShape -> { + is StructureShape, is UnionShape, is MapShape, is ListShape, is StringShape, is IntegerShape -> { shape.hasTrait() } else -> PANIC("this method does not support shape type ${shape.type}") } @@ -40,3 +41,4 @@ fun StructureShape.isReachableFromOperationInput() = isShapeReachableFromOperati fun CollectionShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) fun UnionShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) fun MapShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) +fun IntegerShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt index cf58f3f9d93..0036238e6b2 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt @@ -7,6 +7,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.transformers import software.amazon.smithy.model.Model import software.amazon.smithy.model.neighbor.Walker +import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.StringShape @@ -50,7 +51,7 @@ object ShapesReachableFromOperationInputTagger { return ModelTransformer.create().mapShapes(model) { shape -> when (shape) { - is StructureShape, is UnionShape, is ListShape, is MapShape, is StringShape -> { + is StructureShape, is UnionShape, is ListShape, is MapShape, is StringShape, is IntegerShape -> { if (shapesReachableFromOperationInputs.contains(shape)) { val builder = when (shape) { is StructureShape -> shape.toBuilder() @@ -58,6 +59,7 @@ object ShapesReachableFromOperationInputTagger { is ListShape -> shape.toBuilder() is MapShape -> shape.toBuilder() is StringShape -> shape.toBuilder() + is IntegerShape -> shape.toBuilder() else -> UNREACHABLE("the `when` is exhaustive") } builder.addTrait(ShapeReachableFromOperationInputTagTrait()).build() diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt index bcf7fe34ce5..fb6c2ec1c79 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt @@ -7,14 +7,20 @@ package software.amazon.smithy.rust.codegen.server.smithy import io.kotest.matchers.shouldBe import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.MethodSource +import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider +import java.util.stream.Stream const val baseModelString = """ @@ -32,13 +38,17 @@ const val baseModelString = structure TestInputOutput { constrainedString: ConstrainedString, + constrainedInteger: ConstrainedInteger, constrainedMap: ConstrainedMap, unconstrainedMap: TransitivelyConstrainedMap } @length(min: 1, max: 69) string ConstrainedString - + + @range(min: 10, max: 29) + integer ConstrainedInteger + string UnconstrainedString @length(min: 1, max: 69) @@ -64,24 +74,33 @@ class ConstrainedShapeSymbolProviderTest { private val symbolProvider = serverTestSymbolProvider(model, serviceShape) private val constrainedShapeSymbolProvider = ConstrainedShapeSymbolProvider(symbolProvider, model, serviceShape) - private val constrainedMapShape = model.lookup("test#ConstrainedMap") - private val constrainedMapType = constrainedShapeSymbolProvider.toSymbol(constrainedMapShape).rustType() - - @Test - fun `it should return a constrained string type for a constrained string shape`() { - val constrainedStringShape = model.lookup("test#ConstrainedString") - val constrainedStringType = constrainedShapeSymbolProvider.toSymbol(constrainedStringShape).rustType() - - constrainedStringType shouldBe RustType.Opaque("ConstrainedString", "crate::model") + companion object { + @JvmStatic + fun getConstrainedShapes(): Stream = + Stream.of( + Arguments.of("ConstrainedInteger", { s: Shape -> s is IntegerShape }), + Arguments.of("ConstrainedString", { s: Shape -> s is StringShape }), + Arguments.of("ConstrainedMap", { s: Shape -> s is MapShape }), + ) } - @Test - fun `it should return a constrained map type for a constrained map shape`() { - constrainedMapType shouldBe RustType.Opaque("ConstrainedMap", "crate::model") + @ParameterizedTest + @MethodSource("getConstrainedShapes") + fun `it should return a constrained type for a constrained shape`( + shapeName: String, + shapeCheck: (Shape) -> Boolean, + ) { + val constrainedShape = model.lookup("test#$shapeName") + assert(shapeCheck(constrainedShape)) + val constrainedType = constrainedShapeSymbolProvider.toSymbol(constrainedShape).rustType() + + constrainedType shouldBe RustType.Opaque(shapeName, "crate::model") } @Test fun `it should not blindly delegate to the base symbol provider when the shape is an aggregate shape and is not directly constrained`() { + val constrainedMapShape = model.lookup("test#ConstrainedMap") + val constrainedMapType = constrainedShapeSymbolProvider.toSymbol(constrainedMapShape).rustType() val unconstrainedMapShape = model.lookup("test#TransitivelyConstrainedMap") val unconstrainedMapType = constrainedShapeSymbolProvider.toSymbol(unconstrainedMapShape).rustType() diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt index 80e2d93dae4..76821a90dd1 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt @@ -26,63 +26,58 @@ class ConstraintsTest { version: "123", operations: [TestOperation] } - + operation TestOperation { input: TestInputOutput, output: TestInputOutput, } - + structure TestInputOutput { map: MapA, - recursive: RecursiveShape } - + structure RecursiveShape { shape: RecursiveShape, mapB: MapB } - + @length(min: 1, max: 69) map MapA { key: String, value: MapB } - + map MapB { key: String, value: StructureA } - + @uniqueItems list ListA { member: MyString } - + @pattern("\\w+") string MyString - + @length(min: 1, max: 69) string LengthString - + structure StructureA { @range(min: 1, max: 69) int: Integer, - @required string: String } - + // This shape is not in the service closure. structure StructureB { @pattern("\\w+") patternString: String, - @required requiredString: String, - mapA: MapA, - @length(min: 1, max: 5) mapAPrecedence: MapA } @@ -94,7 +89,6 @@ class ConstraintsTest { private val mapA = model.lookup("test#MapA") private val mapB = model.lookup("test#MapB") private val listA = model.lookup("test#ListA") - private val myString = model.lookup("test#MyString") private val lengthString = model.lookup("test#LengthString") private val structA = model.lookup("test#StructureA") private val structAInt = model.lookup("test#StructureA\$int") @@ -114,7 +108,7 @@ class ConstraintsTest { @Test fun `it should not detect unsupported constrained traits as constrained`() { - listOf(structAInt, structAString, myString).forAll { + listOf(structAInt, structAString).forAll { it.isDirectlyConstrained(symbolProvider) shouldBe false } } @@ -123,9 +117,7 @@ class ConstraintsTest { fun `it should evaluate reachability of constrained shapes`() { mapA.canReachConstrainedShape(model, symbolProvider) shouldBe true structAInt.canReachConstrainedShape(model, symbolProvider) shouldBe false - - // This should be true when we start supporting the `pattern` trait on string shapes. - listA.canReachConstrainedShape(model, symbolProvider) shouldBe false + listA.canReachConstrainedShape(model, symbolProvider) shouldBe true // All of these eventually reach `StructureA`, which is constrained because one of its members is `required`. testInputOutput.canReachConstrainedShape(model, symbolProvider) shouldBe true diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt index 21baefe747d..f0b339a4852 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt @@ -23,24 +23,30 @@ import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymb class PubCrateConstrainedShapeSymbolProviderTest { private val model = """ $baseModelString - + + structure NonTransitivelyConstrainedStructureShape { + constrainedString: ConstrainedString, + constrainedMap: ConstrainedMap, + unconstrainedMap: TransitivelyConstrainedMap + } + list TransitivelyConstrainedCollection { member: Structure } - + structure Structure { @required requiredMember: String } - + structure StructureWithMemberTargetingAggregateShape { member: TransitivelyConstrainedCollection } - + union Union { structure: Structure } - """.asSmithyModel() + """.asSmithyModel() private val serverTestSymbolProviders = serverTestSymbolProviders(model) private val symbolProvider = serverTestSymbolProviders.symbolProvider @@ -97,7 +103,7 @@ class PubCrateConstrainedShapeSymbolProviderTest { @Test fun `it should delegate to the base symbol provider when provided with a structure shape`() { - val structureShape = model.lookup("test#TestInputOutput") + val structureShape = model.lookup("test#NonTransitivelyConstrainedStructureShape") val structureSymbol = pubCrateConstrainedShapeSymbolProvider.toSymbol(structureShape) structureSymbol shouldBe symbolProvider.toSymbol(structureShape) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitorTest.kt index bbb697d93ef..e51d5b813af 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitorTest.kt @@ -13,7 +13,6 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.generatePluginContext import software.amazon.smithy.rust.codegen.server.smithy.customizations.ServerRequiredCustomizations import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator -import kotlin.io.path.createDirectory import kotlin.io.path.writeText class ServerCodegenVisitorTest { @@ -46,7 +45,6 @@ class ServerCodegenVisitorTest { } """.asSmithyModel(smithyVersion = "2.0") val (ctx, testDir) = generatePluginContext(model) - testDir.resolve("src").createDirectory() testDir.resolve("src/main.rs").writeText("fn main() {}") val codegenDecorator: CombinedCodegenDecorator = CombinedCodegenDecorator.fromClasspath(ctx, ServerRequiredCustomizations()) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt index 78cec7408be..25461f18754 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt @@ -22,12 +22,12 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { private val baseModel = """ namespace test - + service TestService { version: "123", operations: [TestOperation] } - + operation TestOperation { input: TestInputOutput, output: TestInputOutput, @@ -44,7 +44,7 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { val model = """ $baseModel - + structure TestInputOutput { @required requiredString: String @@ -62,7 +62,7 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { val model = """ $baseModel - + structure TestInputOutput { @length(min: 1, max: 69) lengthString: String @@ -79,7 +79,7 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { val model = """ $baseModel - + structure TestInputOutput { @required string: String @@ -93,12 +93,12 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { private val constraintTraitOnStreamingBlobShapeModel = """ $baseModel - + structure TestInputOutput { @required streamingBlob: StreamingBlob } - + @streaming @length(min: 69) blob StreamingBlob @@ -123,20 +123,20 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { val model = """ $baseModel - + structure TestInputOutput { eventStream: EventStream } - + @streaming union EventStream { message: Message } - + structure Message { lengthString: LengthString } - + @length(min: 1) string LengthString """.asSmithyModel() @@ -155,17 +155,17 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { val model = """ $baseModel - + structure TestInputOutput { collection: LengthCollection, blob: LengthBlob } - + @length(min: 1) list LengthCollection { member: String } - + @length(min: 1) blob LengthBlob """.asSmithyModel() @@ -177,41 +177,32 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { } @Test - fun `it should detect when the pattern trait on string shapes is used`() { + fun `it should detect when the range trait is used on a shape we do not support`() { val model = """ $baseModel - + structure TestInputOutput { - patternString: PatternString + rangeByte: RangeByte + rangeShort: RangeShort + rangeLong: RangeLong } - - @pattern("^[A-Za-z]+$") - string PatternString - """.asSmithyModel() - val validationResult = validateModel(model) - validationResult.messages shouldHaveSize 1 - validationResult.messages[0].message shouldContain "The string shape `test#PatternString` has the constraint trait `smithy.api#pattern` attached" - } + @range(min: 1) + byte RangeByte + + @range(min: 1) + long RangeLong - @Test - fun `it should detect when the range trait is used`() { - val model = - """ - $baseModel - - structure TestInputOutput { - rangeInteger: RangeInteger - } - @range(min: 1) - integer RangeInteger + short RangeShort """.asSmithyModel() val validationResult = validateModel(model) - validationResult.messages shouldHaveSize 1 - validationResult.messages[0].message shouldContain "The integer shape `test#RangeInteger` has the constraint trait `smithy.api#range` attached" + validationResult.messages shouldHaveSize 3 + validationResult.messages[0].message shouldContain "The long shape `test#RangeLong` has the constraint trait `smithy.api#range` attached" + validationResult.messages[1].message shouldContain "The short shape `test#RangeShort` has the constraint trait `smithy.api#range` attached" + validationResult.messages[2].message shouldContain "The byte shape `test#RangeByte` has the constraint trait `smithy.api#range` attached" } @Test @@ -219,11 +210,11 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { val model = """ $baseModel - + structure TestInputOutput { uniqueItemsList: UniqueItemsList } - + @uniqueItems list UniqueItemsList { member: String diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedIntegerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedIntegerGeneratorTest.kt new file mode 100644 index 00000000000..a889281327b --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedIntegerGeneratorTest.kt @@ -0,0 +1,124 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import io.kotest.matchers.string.shouldContain +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider +import org.junit.jupiter.params.provider.ArgumentsSource +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext +import java.util.stream.Stream + +class ConstrainedIntegerGeneratorTest { + + data class TestCase(val model: Model, val validInteger: Int, val invalidInteger: Int) + + class ConstrainedIntGeneratorTestProvider : ArgumentsProvider { + private val testCases = listOf( + // Min and max. + Triple("@range(min: 10, max: 12)", 11, 13), + // Min equal to max. + Triple("@range(min: 11, max: 11)", 11, 12), + // Only min. + Triple("@range(min: 11)", 12, 2), + // Only max. + Triple("@range(max: 11)", 0, 12), + ).map { + TestCase( + """ + namespace test + + ${it.first} + integer ConstrainedInteger + """.asSmithyModel(), + it.second, + it.third, + ) + } + + override fun provideArguments(context: ExtensionContext?): Stream = + testCases.map { Arguments.of(it) }.stream() + } + + @ParameterizedTest + @ArgumentsSource(ConstrainedIntGeneratorTestProvider::class) + fun `it should generate constrained integer types`(testCase: TestCase) { + val constrainedIntegerShape = testCase.model.lookup("test#ConstrainedInteger") + + val codegenContext = serverTestCodegenContext(testCase.model) + val symbolProvider = codegenContext.symbolProvider + + val project = TestWorkspace.testProject(symbolProvider) + + project.withModule(ModelsModule) { + ConstrainedIntegerGenerator(codegenContext, this, constrainedIntegerShape).render() + + unitTest( + name = "try_from_success", + test = """ + let _constrained: ConstrainedInteger = ${testCase.validInteger}.try_into().unwrap(); + """, + ) + unitTest( + name = "try_from_fail", + test = """ + let constrained_res: Result = ${testCase.invalidInteger}.try_into(); + constrained_res.unwrap_err(); + """, + ) + unitTest( + name = "inner", + test = """ + let constrained = ConstrainedInteger::try_from(${testCase.validInteger}).unwrap(); + assert_eq!(constrained.inner(), &${testCase.validInteger}); + """, + ) + unitTest( + name = "into_inner", + test = """ + let int = ${testCase.validInteger}; + let constrained = ConstrainedInteger::try_from(int).unwrap(); + + assert_eq!(constrained.into_inner(), int); + """, + ) + } + + project.compileAndTest() + } + + @Test + fun `type should not be constructible without using a constructor`() { + val model = """ + namespace test + + @range(min: -1, max: 69) + integer ConstrainedInteger + """.asSmithyModel() + val constrainedIntegerShape = model.lookup("test#ConstrainedInteger") + + val codegenContext = serverTestCodegenContext(model) + + val writer = RustWriter.forModule(ModelsModule.name) + + ConstrainedIntegerGenerator(codegenContext, writer, constrainedIntegerShape).render() + + // Check that the wrapped type is `pub(crate)`. + writer.toString() shouldContain "pub struct ConstrainedInteger(pub(crate) i32);" + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt index 75db6303f7a..ddf9a53d073 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt @@ -25,7 +25,6 @@ import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCode import java.util.stream.Stream class ConstrainedStringGeneratorTest { - data class TestCase(val model: Model, val validString: String, val invalidString: String) class ConstrainedStringGeneratorTestProvider : ArgumentsProvider { @@ -44,11 +43,20 @@ class ConstrainedStringGeneratorTest { "👍👍👍", // These three emojis are three Unicode scalar values. "👍👍👍👍", ), + Triple("@pattern(\"^[a-z]+$\")", "valid", "123 invalid"), + Triple( + """ + @length(min: 3, max: 10) + @pattern("^a string$") + """, + "a string", "an invalid string", + ), + Triple("@pattern(\"123\")", "some pattern 123 in the middle", "no pattern at all"), ).map { TestCase( """ namespace test - + ${it.first} string ConstrainedString """.asSmithyModel(), @@ -116,10 +124,10 @@ class ConstrainedStringGeneratorTest { fun `type should not be constructible without using a constructor`() { val model = """ namespace test - + @length(min: 1, max: 69) string ConstrainedString - """.asSmithyModel() + """.asSmithyModel() val constrainedStringShape = model.lookup("test#ConstrainedString") val codegenContext = serverTestCodegenContext(model) @@ -136,14 +144,14 @@ class ConstrainedStringGeneratorTest { fun `Display implementation`() { val model = """ namespace test - + @length(min: 1, max: 69) string ConstrainedString - + @sensitive @length(min: 1, max: 78) string SensitiveConstrainedString - """.asSmithyModel() + """.asSmithyModel() val constrainedStringShape = model.lookup("test#ConstrainedString") val sensitiveConstrainedStringShape = model.lookup("test#SensitiveConstrainedString") diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGeneratorTest.kt index d414aa63fc7..74658c3a674 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGeneratorTest.kt @@ -7,7 +7,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ServerCombinedErrorGenerator import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace @@ -52,7 +52,7 @@ class ServerCombinedErrorGeneratorTest { @Test fun `generates combined error enums`() { val project = TestWorkspace.testProject(symbolProvider) - project.withModule(RustModule.public("error")) { + project.withModule(ErrorsModule) { listOf("FooException", "ComplexError", "InvalidGreeting", "Deprecated").forEach { model.lookup("error#$it").serverRenderWithModelBuilder(model, symbolProvider, this) } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGeneratorTest.kt index 42774274d9d..50543d67260 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGeneratorTest.kt @@ -8,8 +8,9 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.smithy.ConstrainedModule import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.smithy.UnconstrainedModule import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest @@ -28,15 +29,15 @@ class UnconstrainedCollectionGeneratorTest { list ListA { member: ListB } - + list ListB { member: StructureC } - + structure StructureC { @required int: Integer, - + @required string: String } @@ -49,16 +50,16 @@ class UnconstrainedCollectionGeneratorTest { val project = TestWorkspace.testProject(symbolProvider) - project.withModule(RustModule.public("model")) { + project.withModule(ModelsModule) { model.lookup("test#StructureC").serverRenderWithModelBuilder(model, symbolProvider, this) } - project.withModule(RustModule.private("constrained")) { + project.withModule(ConstrainedModule) { listOf(listA, listB).forEach { PubCrateConstrainedCollectionGenerator(codegenContext, this, it).render() } } - project.withModule(RustModule.private("unconstrained")) unconstrainedModuleWriter@{ + project.withModule(UnconstrainedModule) unconstrainedModuleWriter@{ project.withModule(ModelsModule) modelsModuleWriter@{ listOf(listA, listB).forEach { UnconstrainedCollectionGenerator( @@ -72,53 +73,53 @@ class UnconstrainedCollectionGeneratorTest { this@unconstrainedModuleWriter.unitTest( name = "list_a_unconstrained_fail_to_constrain_with_first_error", test = """ - let c_builder1 = crate::model::StructureC::builder().int(69); - let c_builder2 = crate::model::StructureC::builder().string("david".to_owned()); - let list_b_unconstrained = list_b_unconstrained::ListBUnconstrained(vec![c_builder1, c_builder2]); - let list_a_unconstrained = list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]); - - let expected_err = - crate::model::list_a::ConstraintViolation(0, crate::model::list_b::ConstraintViolation( - 0, crate::model::structure_c::ConstraintViolation::MissingString, - )); - - assert_eq!( - expected_err, - crate::constrained::list_a_constrained::ListAConstrained::try_from(list_a_unconstrained).unwrap_err() - ); + let c_builder1 = crate::model::StructureC::builder().int(69); + let c_builder2 = crate::model::StructureC::builder().string("david".to_owned()); + let list_b_unconstrained = list_b_unconstrained::ListBUnconstrained(vec![c_builder1, c_builder2]); + let list_a_unconstrained = list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]); + + let expected_err = + crate::model::list_a::ConstraintViolation(0, crate::model::list_b::ConstraintViolation( + 0, crate::model::structure_c::ConstraintViolation::MissingString, + )); + + assert_eq!( + expected_err, + crate::constrained::list_a_constrained::ListAConstrained::try_from(list_a_unconstrained).unwrap_err() + ); """, ) this@unconstrainedModuleWriter.unitTest( name = "list_a_unconstrained_succeed_to_constrain", test = """ - let c_builder = crate::model::StructureC::builder().int(69).string(String::from("david")); - let list_b_unconstrained = list_b_unconstrained::ListBUnconstrained(vec![c_builder]); - let list_a_unconstrained = list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]); - - let expected: Vec> = vec![vec![crate::model::StructureC { - string: "david".to_owned(), - int: 69 - }]]; - let actual: Vec> = - crate::constrained::list_a_constrained::ListAConstrained::try_from(list_a_unconstrained).unwrap().into(); - - assert_eq!(expected, actual); + let c_builder = crate::model::StructureC::builder().int(69).string(String::from("david")); + let list_b_unconstrained = list_b_unconstrained::ListBUnconstrained(vec![c_builder]); + let list_a_unconstrained = list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]); + + let expected: Vec> = vec![vec![crate::model::StructureC { + string: "david".to_owned(), + int: 69 + }]]; + let actual: Vec> = + crate::constrained::list_a_constrained::ListAConstrained::try_from(list_a_unconstrained).unwrap().into(); + + assert_eq!(expected, actual); """, ) this@unconstrainedModuleWriter.unitTest( name = "list_a_unconstrained_converts_into_constrained", test = """ - let c_builder = crate::model::StructureC::builder(); - let list_b_unconstrained = list_b_unconstrained::ListBUnconstrained(vec![c_builder]); - let list_a_unconstrained = list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]); + let c_builder = crate::model::StructureC::builder(); + let list_b_unconstrained = list_b_unconstrained::ListBUnconstrained(vec![c_builder]); + let list_a_unconstrained = list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]); - let _list_a: crate::constrained::MaybeConstrained = list_a_unconstrained.into(); + let _list_a: crate::constrained::MaybeConstrained = list_a_unconstrained.into(); """, ) - project.compileAndTest() } } + project.compileAndTest() } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGeneratorTest.kt index a5877b7c007..75e176be39d 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGeneratorTest.kt @@ -8,8 +8,9 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.smithy.ConstrainedModule import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.smithy.UnconstrainedModule import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest @@ -29,16 +30,16 @@ class UnconstrainedMapGeneratorTest { key: String, value: MapB } - + map MapB { key: String, value: StructureC } - + structure StructureC { @required int: Integer, - + @required string: String } @@ -49,18 +50,18 @@ class UnconstrainedMapGeneratorTest { val mapA = model.lookup("test#MapA") val mapB = model.lookup("test#MapB") - val project = TestWorkspace.testProject(symbolProvider) + val project = TestWorkspace.testProject(symbolProvider, debugMode = true) - project.withModule(RustModule.public("model")) { + project.withModule(ModelsModule) { model.lookup("test#StructureC").serverRenderWithModelBuilder(model, symbolProvider, this) } - project.withModule(RustModule.private("constrained")) { + project.withModule(ConstrainedModule) { listOf(mapA, mapB).forEach { PubCrateConstrainedMapGenerator(codegenContext, this, it).render() } } - project.withModule(RustModule.private("unconstrained")) unconstrainedModuleWriter@{ + project.withModule(UnconstrainedModule) unconstrainedModuleWriter@{ project.withModule(ModelsModule) modelsModuleWriter@{ listOf(mapA, mapB).forEach { UnconstrainedMapGenerator(codegenContext, this@unconstrainedModuleWriter, it).render() @@ -100,65 +101,65 @@ class UnconstrainedMapGeneratorTest { crate::model::structure_c::ConstraintViolation::MissingInt, ) ); - + let actual_err = crate::constrained::map_a_constrained::MapAConstrained::try_from(map_a_unconstrained).unwrap_err(); assert!(actual_err == missing_string_expected_err || actual_err == missing_int_expected_err); - """, + """, ) this@unconstrainedModuleWriter.unitTest( name = "map_a_unconstrained_succeed_to_constrain", test = """ - let c_builder = crate::model::StructureC::builder().int(69).string(String::from("david")); - let map_b_unconstrained = map_b_unconstrained::MapBUnconstrained( - std::collections::HashMap::from([ - (String::from("KeyB"), c_builder), - ]) - ); - let map_a_unconstrained = map_a_unconstrained::MapAUnconstrained( - std::collections::HashMap::from([ - (String::from("KeyA"), map_b_unconstrained), - ]) - ); - - let expected = std::collections::HashMap::from([ - (String::from("KeyA"), std::collections::HashMap::from([ - (String::from("KeyB"), crate::model::StructureC { - int: 69, - string: String::from("david") - }), - ])) - ]); - - assert_eq!( - expected, - crate::constrained::map_a_constrained::MapAConstrained::try_from(map_a_unconstrained).unwrap().into() - ); + let c_builder = crate::model::StructureC::builder().int(69).string(String::from("david")); + let map_b_unconstrained = map_b_unconstrained::MapBUnconstrained( + std::collections::HashMap::from([ + (String::from("KeyB"), c_builder), + ]) + ); + let map_a_unconstrained = map_a_unconstrained::MapAUnconstrained( + std::collections::HashMap::from([ + (String::from("KeyA"), map_b_unconstrained), + ]) + ); + + let expected = std::collections::HashMap::from([ + (String::from("KeyA"), std::collections::HashMap::from([ + (String::from("KeyB"), crate::model::StructureC { + int: 69, + string: String::from("david") + }), + ])) + ]); + + assert_eq!( + expected, + crate::constrained::map_a_constrained::MapAConstrained::try_from(map_a_unconstrained).unwrap().into() + ); """, ) this@unconstrainedModuleWriter.unitTest( name = "map_a_unconstrained_converts_into_constrained", test = """ - let c_builder = crate::model::StructureC::builder(); - let map_b_unconstrained = map_b_unconstrained::MapBUnconstrained( - std::collections::HashMap::from([ - (String::from("KeyB"), c_builder), - ]) - ); - let map_a_unconstrained = map_a_unconstrained::MapAUnconstrained( - std::collections::HashMap::from([ - (String::from("KeyA"), map_b_unconstrained), - ]) - ); - - let _map_a: crate::constrained::MaybeConstrained = map_a_unconstrained.into(); + let c_builder = crate::model::StructureC::builder(); + let map_b_unconstrained = map_b_unconstrained::MapBUnconstrained( + std::collections::HashMap::from([ + (String::from("KeyB"), c_builder), + ]) + ); + let map_a_unconstrained = map_a_unconstrained::MapAUnconstrained( + std::collections::HashMap::from([ + (String::from("KeyA"), map_b_unconstrained), + ]) + ); + + let _map_a: crate::constrained::MaybeConstrained = map_a_unconstrained.into(); """, ) - - project.compileAndTest() } } + + project.compileAndTest() } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGeneratorTest.kt index f31285de980..4b5eca2d1e4 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGeneratorTest.kt @@ -8,8 +8,8 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.smithy.UnconstrainedModule import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel @@ -29,7 +29,7 @@ class UnconstrainedUnionGeneratorTest { union Union { structure: Structure } - + structure Structure { @required requiredMember: String @@ -42,14 +42,14 @@ class UnconstrainedUnionGeneratorTest { val project = TestWorkspace.testProject(symbolProvider) - project.withModule(RustModule.public("model")) { + project.withModule(ModelsModule) { model.lookup("test#Structure").serverRenderWithModelBuilder(model, symbolProvider, this) } project.withModule(ModelsModule) { UnionGenerator(model, symbolProvider, this, unionShape, renderUnknownVariant = false).render() } - project.withModule(RustModule.private("unconstrained")) unconstrainedModuleWriter@{ + project.withModule(UnconstrainedModule) unconstrainedModuleWriter@{ project.withModule(ModelsModule) modelsModuleWriter@{ UnconstrainedUnionGenerator(codegenContext, this@unconstrainedModuleWriter, this@modelsModuleWriter, unionShape).render() @@ -67,7 +67,7 @@ class UnconstrainedUnionGeneratorTest { expected_err, crate::model::Union::try_from(union_unconstrained).unwrap_err() ); - """, + """, ) this@unconstrainedModuleWriter.unitTest( @@ -82,7 +82,7 @@ class UnconstrainedUnionGeneratorTest { let actual: crate::model::Union = crate::model::Union::try_from(union_unconstrained).unwrap(); assert_eq!(expected, actual); - """, + """, ) this@unconstrainedModuleWriter.unitTest( @@ -93,10 +93,10 @@ class UnconstrainedUnionGeneratorTest { let _union: crate::constrained::MaybeConstrained = union_unconstrained.into(); - """, + """, ) - project.compileAndTest() } } + project.compileAndTest() } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/EventStreamTestTools.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/EventStreamTestTools.kt index 507180d6cc3..38528f44a24 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/EventStreamTestTools.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/EventStreamTestTools.kt @@ -22,6 +22,8 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator @@ -356,7 +358,7 @@ object EventStreamTestTools { } val project = TestWorkspace.testProject(symbolProvider) val operationSymbol = symbolProvider.toSymbol(operationShape) - project.withModule(RustModule.public("error")) { + project.withModule(ErrorsModule) { val errors = model.shapes() .filter { shape -> shape.isStructureShape && shape.hasTrait() } .map { it.asStructureShape().get() } @@ -374,11 +376,11 @@ object EventStreamTestTools { } } } - project.withModule(RustModule.public("model")) { + project.withModule(ModelsModule) { val inputOutput = model.lookup("test#TestStreamInputOutput") recursivelyGenerateModels(model, symbolProvider, inputOutput, this, testCase.target) } - project.withModule(RustModule.public("output")) { + project.withModule(RustModule.Output) { operationShape.outputShape(model).renderWithModelBuilder(model, symbolProvider, this) } return TestEventStreamProject(model, serviceShape, operationShape, unionShape, symbolProvider, project) diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/bin/pokemon-service-connect-info.rs b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/bin/pokemon-service-connect-info.rs index e7422f7db06..8d17756ff14 100644 --- a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/bin/pokemon-service-connect-info.rs +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/bin/pokemon-service-connect-info.rs @@ -3,10 +3,19 @@ * SPDX-License-Identifier: Apache-2.0 */ +use std::net::{IpAddr, SocketAddr}; + +use aws_smithy_http_server::request::connect_info::ConnectInfo; use clap::Parser; use pokemon_service::{ capture_pokemon, check_health, do_nothing, get_pokemon_species, get_server_statistics, setup_tracing, }; +use pokemon_service_server_sdk::{ + error::{GetStorageError, NotAuthorized}, + input::GetStorageInput, + output::GetStorageOutput, + PokemonService, +}; #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] @@ -21,20 +30,20 @@ struct Args { /// Retrieves the user's storage. No authentication required for locals. pub async fn get_storage_with_local_approved( - input: pokemon_service_server_sdk::input::GetStorageInput, - connect_info: aws_smithy_http_server::Extension>, -) -> Result { + input: GetStorageInput, + connect_info: ConnectInfo, +) -> Result { tracing::debug!("attempting to authenticate storage user"); - let local = connect_info.0 .0.ip() == "127.0.0.1".parse::().unwrap(); + let local = connect_info.0.ip() == "127.0.0.1".parse::().unwrap(); // We currently support Ash: he has nothing stored if input.user == "ash" && input.passcode == "pikachu123" { - return Ok(pokemon_service_server_sdk::output::GetStorageOutput { collection: vec![] }); + return Ok(GetStorageOutput { collection: vec![] }); } // We support trainers in our gym if local { tracing::info!("welcome back"); - return Ok(pokemon_service_server_sdk::output::GetStorageOutput { + return Ok(GetStorageOutput { collection: vec![ String::from("bulbasaur"), String::from("charmander"), @@ -43,16 +52,14 @@ pub async fn get_storage_with_local_approved( }); } tracing::debug!("authentication failed"); - Err(pokemon_service_server_sdk::error::GetStorageError::NotAuthorized( - pokemon_service_server_sdk::error::NotAuthorized {}, - )) + Err(GetStorageError::NotAuthorized(NotAuthorized {})) } #[tokio::main] async fn main() { let args = Args::parse(); setup_tracing(); - let app = pokemon_service_server_sdk::service::PokemonService::builder_without_plugins() + let app = PokemonService::builder_without_plugins() .get_pokemon_species(get_pokemon_species) .get_storage(get_storage_with_local_approved) .get_server_statistics(get_server_statistics) diff --git a/rust-runtime/aws-smithy-http-server/src/extension.rs b/rust-runtime/aws-smithy-http-server/src/extension.rs index 220b697c78e..3dfc4fc5d5d 100644 --- a/rust-runtime/aws-smithy-http-server/src/extension.rs +++ b/rust-runtime/aws-smithy-http-server/src/extension.rs @@ -3,35 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -// This code was copied and then modified from Tokio's Axum. - -/* Copyright (c) 2021 Tower Contributors - * - * Permission is hereby granted, free of charge, to any - * person obtaining a copy of this software and associated - * documentation files (the "Software"), to deal in the - * Software without restriction, including without - * limitation the rights to use, copy, modify, merge, - * publish, distribute, sublicense, and/or sell copies of - * the Software, and to permit persons to whom the Software - * is furnished to do so, subject to the following - * conditions: - * - * The above copyright notice and this permission notice - * shall be included in all copies or substantial portions - * of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF - * ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED - * TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A - * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT - * SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION - * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR - * IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER - * DEALINGS IN THE SOFTWARE. - */ - //! Extension types. //! //! Extension types are types that are stored in and extracted from _both_ requests and @@ -50,14 +21,12 @@ use std::ops::Deref; -use http::StatusCode; use thiserror::Error; -use crate::{ - body::{empty, BoxBody}, - request::{FromParts, RequestParts}, - response::IntoResponse, -}; +use crate::request::RequestParts; + +pub use crate::request::extension::Extension; +pub use crate::request::extension::MissingExtension; /// Extension type used to store information about Smithy operations in HTTP responses. /// This extension type is set when it has been correctly determined that the request should be @@ -151,49 +120,6 @@ impl Deref for RuntimeErrorExtension { } } -/// Generic extension type stored in and extracted from [request extensions]. -/// -/// This is commonly used to share state across handlers. -/// -/// If the extension is missing it will reject the request with a `500 Internal -/// Server Error` response. -/// -/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html -#[derive(Debug, Clone)] -pub struct Extension(pub T); - -impl Deref for Extension { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -/// The extension has not been added to the [`Request`](http::Request) or has been previously removed. -#[derive(Debug, Error)] -#[error("the `Extension` is not present in the `http::Request`")] -pub struct MissingExtension; - -impl IntoResponse for MissingExtension { - fn into_response(self) -> http::Response { - let mut response = http::Response::new(empty()); - *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - response - } -} - -impl FromParts for Extension -where - T: Send + Sync + 'static, -{ - type Rejection = MissingExtension; - - fn from_parts(parts: &mut http::request::Parts) -> Result { - parts.extensions.remove::().map(Extension).ok_or(MissingExtension) - } -} - /// Extract an [`Extension`] from a request. /// This is essentially the implementation of `FromRequest` for `Extension`, but with a /// protocol-agnostic rejection type. The actual code-generated implementation simply delegates to diff --git a/rust-runtime/aws-smithy-http-server/src/lib.rs b/rust-runtime/aws-smithy-http-server/src/lib.rs index 18ae005d08a..64f512a905e 100644 --- a/rust-runtime/aws-smithy-http-server/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server/src/lib.rs @@ -36,7 +36,7 @@ pub mod routers; #[doc(inline)] pub(crate) use self::error::Error; -pub use self::extension::Extension; +pub use self::request::extension::Extension; #[doc(inline)] pub use self::routing::Router; #[doc(inline)] diff --git a/rust-runtime/aws-smithy-http-server/src/plugin/mod.rs b/rust-runtime/aws-smithy-http-server/src/plugin/mod.rs index c6c13aa6c02..a41faa3c35a 100644 --- a/rust-runtime/aws-smithy-http-server/src/plugin/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/plugin/mod.rs @@ -3,6 +3,53 @@ * SPDX-License-Identifier: Apache-2.0 */ +//! The plugin system allows you to build middleware with an awareness of the operation it is applied to. +//! +//! The system centers around the [`Plugin`] trait. In addition, this module provides helpers for composing and +//! combining [`Plugin`]s. +//! +//! # Filtered application of a HTTP [`Layer`](tower::Layer) +//! +//! ``` +//! # use aws_smithy_http_server::plugin::*; +//! # let layer = (); +//! # struct GetPokemonSpecies; +//! # impl GetPokemonSpecies { const NAME: &'static str = ""; }; +//! // Create a `Plugin` from a HTTP `Layer` +//! let plugin = HttpLayer(layer); +//! +//! // Only apply the layer to operations with name "GetPokemonSpecies" +//! let plugin = filter_by_operation_name(plugin, |name| name == GetPokemonSpecies::NAME); +//! ``` +//! +//! # Construct a [`Plugin`] from a closure that takes as input the operation name +//! +//! ``` +//! # use aws_smithy_http_server::plugin::*; +//! // A `tower::Layer` which requires the operation name +//! struct PrintLayer { +//! name: &'static str +//! } +//! +//! // Create a `Plugin` using `PrintLayer` +//! let plugin = plugin_from_operation_name_fn(|name| PrintLayer { name }); +//! ``` +//! +//! # Combine [`Plugin`]s +//! +//! ``` +//! # use aws_smithy_http_server::plugin::*; +//! # let a = (); let b = (); +//! // Combine `Plugin`s `a` and `b` +//! let plugin = PluginPipeline::new() +//! .push(a) +//! .push(b); +//! ``` +//! +//! As noted in the [`PluginPipeline`] documentation, the plugins' runtime logic is executed in registration order, +//! meaning that `a` is run _before_ `b` in the example above. +//! + mod closure; mod filter; mod identity; diff --git a/rust-runtime/aws-smithy-http-server/src/routing/into_make_service_with_connect_info.rs b/rust-runtime/aws-smithy-http-server/src/request/connect_info.rs similarity index 84% rename from rust-runtime/aws-smithy-http-server/src/routing/into_make_service_with_connect_info.rs rename to rust-runtime/aws-smithy-http-server/src/request/connect_info.rs index 8bd75238efe..fd66475fc9b 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/into_make_service_with_connect_info.rs +++ b/rust-runtime/aws-smithy-http-server/src/request/connect_info.rs @@ -32,6 +32,8 @@ * DEALINGS IN THE SOFTWARE. */ +//! Extractor for getting connection information from a client. + use std::{ convert::Infallible, fmt, @@ -48,12 +50,11 @@ use tower_http::add_extension::{AddExtension, AddExtensionLayer}; use crate::{request::FromParts, Extension}; -/// A [`MakeService`] created from a router. +/// A [`MakeService`] used to insert [`ConnectInfo`] into [`http::Request`]s. /// -/// See [`Router::into_make_service_with_connect_info`] for more details. +/// The `T` must be derivable from the underlying IO resource using the [`Connected`] trait. /// /// [`MakeService`]: tower::make::MakeService -/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info pub struct IntoMakeServiceWithConnectInfo { inner: S, _connect_info: PhantomData C>, @@ -96,10 +97,6 @@ where /// /// The goal for this trait is to allow users to implement custom IO types that /// can still provide the same connection metadata. -/// -/// See [`Router::into_make_service_with_connect_info`] for more details. -/// -/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info pub trait Connected: Clone { /// Create type holding information about the connection. fn connect_info(target: T) -> Self; @@ -140,13 +137,9 @@ opaque_future! { /// Extractor for getting connection information produced by a `Connected`. /// -/// Note this extractor requires you to use -/// [`Router::into_make_service_with_connect_info`] to run your app -/// otherwise it will fail at runtime. -/// -/// See [`Router::into_make_service_with_connect_info`] for more details. -/// -/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info +/// Note this extractor requires the existence of [`Extension>`] in the [`http::Extensions`]. This is +/// automatically inserted by the [`IntoMakeServiceWithConnectInfo`] middleware, which can be applied using the +/// `into_make_service_with_connect_info` method on your generated service. #[derive(Clone, Debug)] pub struct ConnectInfo(pub T); diff --git a/rust-runtime/aws-smithy-http-server/src/request/extension.rs b/rust-runtime/aws-smithy-http-server/src/request/extension.rs new file mode 100644 index 00000000000..7aaf23ff2cc --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/request/extension.rs @@ -0,0 +1,103 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +// This code was copied and then modified from Tokio's Axum. + +/* Copyright (c) 2021 Tower Contributors + * + * Permission is hereby granted, free of charge, to any + * person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the + * Software without restriction, including without + * limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software + * is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice + * shall be included in all copies or substantial portions + * of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF + * ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED + * TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A + * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT + * SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR + * IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +//! Extension types. +//! +//! Extension types are types that are stored in and extracted from _both_ requests and +//! responses. +//! +//! There is only one _generic_ extension type _for requests_, [`Extension`]. +//! +//! On the other hand, the server SDK uses multiple concrete extension types for responses in order +//! to store a variety of information, like the operation that was executed, the operation error +//! that got returned, or the runtime error that happened, among others. The information stored in +//! these types may be useful to [`tower::Layer`]s that post-process the response: for instance, a +//! particular metrics layer implementation might want to emit metrics about the number of times an +//! an operation got executed. +//! +//! [extensions]: https://docs.rs/http/latest/http/struct.Extensions.html + +use std::ops::Deref; + +use http::StatusCode; +use thiserror::Error; + +use crate::{ + body::{empty, BoxBody}, + request::FromParts, + response::IntoResponse, +}; + +/// Generic extension type stored in and extracted from [request extensions]. +/// +/// This is commonly used to share state across handlers. +/// +/// If the extension is missing it will reject the request with a `500 Internal +/// Server Error` response. +/// +/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html +#[derive(Debug, Clone)] +pub struct Extension(pub T); + +impl Deref for Extension { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +/// The extension has not been added to the [`Request`](http::Request) or has been previously removed. +#[derive(Debug, Error)] +#[error("the `Extension` is not present in the `http::Request`")] +pub struct MissingExtension; + +impl IntoResponse for MissingExtension { + fn into_response(self) -> http::Response { + let mut response = http::Response::new(empty()); + *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + response + } +} + +impl FromParts for Extension +where + T: Send + Sync + 'static, +{ + type Rejection = MissingExtension; + + fn from_parts(parts: &mut http::request::Parts) -> Result { + parts.extensions.remove::().map(Extension).ok_or(MissingExtension) + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/request.rs b/rust-runtime/aws-smithy-http-server/src/request/mod.rs similarity index 93% rename from rust-runtime/aws-smithy-http-server/src/request.rs rename to rust-runtime/aws-smithy-http-server/src/request/mod.rs index 9faf1ea58d3..efcf3c60414 100644 --- a/rust-runtime/aws-smithy-http-server/src/request.rs +++ b/rust-runtime/aws-smithy-http-server/src/request/mod.rs @@ -32,6 +32,11 @@ * DEALINGS IN THE SOFTWARE. */ +//! Types and traits for extracting data from requests. +//! +//! See [Accessing Un-modelled data](https://github.com/awslabs/smithy-rs/blob/main/design/src/server/from_parts.md) +//! a comprehensive overview. + use std::{ convert::Infallible, future::{ready, Future, Ready}, @@ -45,6 +50,9 @@ use http::{request::Parts, Extensions, HeaderMap, Request, Uri}; use crate::{rejection::any_rejections, response::IntoResponse}; +pub mod connect_info; +pub mod extension; + #[doc(hidden)] #[derive(Debug)] pub struct RequestParts { @@ -111,8 +119,8 @@ impl RequestParts { } } -/// Provides a protocol aware extraction from a [`Request`]. This borrows the -/// [`Parts`], in contrast to [`FromRequest`]. +// NOTE: We cannot reference `FromRequest` here, as a point of contrast, as it's `doc(hidden)`. +/// Provides a protocol aware extraction from a requests [`Parts`]. pub trait FromParts: Sized { type Rejection: IntoResponse; diff --git a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs index 1c84d794cf2..1cd3dd21e75 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs @@ -29,7 +29,6 @@ use tower_http::map_response_body::MapResponseBodyLayer; mod future; mod into_make_service; -mod into_make_service_with_connect_info; mod lambda_handler; #[doc(hidden)] @@ -40,10 +39,7 @@ mod route; pub(crate) mod tiny_map; pub use self::lambda_handler::LambdaHandler; -pub use self::{ - future::RouterFuture, into_make_service::IntoMakeService, into_make_service_with_connect_info::ConnectInfo, - into_make_service_with_connect_info::IntoMakeServiceWithConnectInfo, route::Route, -}; +pub use self::{future::RouterFuture, into_make_service::IntoMakeService, route::Route}; /// The router is a [`tower::Service`] that routes incoming requests to other `Service`s /// based on the request's URI and HTTP method or on some specific header setting the target operation. @@ -120,18 +116,6 @@ where IntoMakeService::new(self) } - /// Convert this router into a [`MakeService`], that is a [`Service`] whose - /// response is another service, and provides a [`ConnectInfo`] object to service handlers. - /// - /// This is useful when running your application with hyper's - /// [`Server`]. - /// - /// [`Server`]: hyper::server::Server - /// [`MakeService`]: tower::make::MakeService - pub fn into_make_service_with_connect_info(self) -> IntoMakeServiceWithConnectInfo { - IntoMakeServiceWithConnectInfo::new(self) - } - /// Apply a [`tower::Layer`] to the router. /// /// All requests to the router will be processed by the layer's diff --git a/rust-runtime/aws-smithy-http-server/src/routing/route.rs b/rust-runtime/aws-smithy-http-server/src/routing/route.rs index 8964c9629f4..67d7ec53530 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/route.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/route.rs @@ -46,12 +46,15 @@ use tower::{ Service, ServiceExt, }; -/// How routes are stored inside a [`Router`](super::Router). +/// A HTTP [`Service`] representing a single route. +/// +/// The construction of [`Route`] from a named HTTP [`Service`] `S`, erases the type of `S`. pub struct Route { service: BoxCloneService, Response, Infallible>, } impl Route { + /// Constructs a new [`Route`] from a well-formed HTTP service which is cloneable. pub fn new(svc: T) -> Self where T: Service, Response = Response, Error = Infallible> + Clone + Send + 'static, diff --git a/tools/ci-cdk/README.md b/tools/ci-cdk/README.md index d126045c3f2..31c4a2fc6ef 100644 --- a/tools/ci-cdk/README.md +++ b/tools/ci-cdk/README.md @@ -7,7 +7,7 @@ The `cdk.json` file tells the CDK Toolkit how to synthesize the infrastructure. ## Canary local development -Sometimes it's useful to only deploy the the canary resources to a test AWS account to iterate +Sometimes it's useful to only deploy the canary resources to a test AWS account to iterate on the `canary-runner` and `canary-lambda`. To do this, run the following: ```bash @@ -21,10 +21,10 @@ From there, you can just point the `canary-runner` to the `cdk-outputs.json` to ```bash cd canary-runner -cargo run -- --sdk-version --musl --cdk-outputs ../cdk-outputs.json +cargo run -- run --sdk-release-tag --musl --cdk-outputs ../cdk-outputs.json ``` -__NOTE:__ You may want to add a `--profile` to the deploy command to select a specific credential +__NOTE:__ You may want to add a `--profile` to the `deploy` command to select a specific credential profile to deploy to if you don't want to use the default. Also, if this is a new test AWS account, be sure it CDK bootstrap it before attempting to deploy. diff --git a/tools/ci-cdk/canary-runner/src/build_bundle.rs b/tools/ci-cdk/canary-runner/src/build_bundle.rs index 7a47c96baee..4f31d75d9c1 100644 --- a/tools/ci-cdk/canary-runner/src/build_bundle.rs +++ b/tools/ci-cdk/canary-runner/src/build_bundle.rs @@ -3,6 +3,13 @@ * SPDX-License-Identifier: Apache-2.0 */ +use std::fmt::Write as FmtWrite; +use std::fs; +use std::io::Write as IoWrite; +use std::path::{Path, PathBuf}; +use std::process::Command; +use std::str::FromStr; + use anyhow::{bail, Context, Result}; use clap::Parser; use lazy_static::lazy_static; @@ -11,12 +18,6 @@ use smithy_rs_tool_common::here; use smithy_rs_tool_common::release_tag::ReleaseTag; use smithy_rs_tool_common::shell::handle_failure; use smithy_rs_tool_common::versions_manifest::VersionsManifest; -use std::fmt::Write as FmtWrite; -use std::fs; -use std::io::Write as IoWrite; -use std::path::{Path, PathBuf}; -use std::process::Command; -use std::str::FromStr; const BASE_MANIFEST: &str = r#" # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. @@ -251,12 +252,14 @@ pub async fn build_bundle(opt: BuildBundleArgs) -> Result> { #[cfg(test)] mod tests { - use super::*; - use crate::Args; use clap::Parser; use smithy_rs_tool_common::package::PackageCategory; use smithy_rs_tool_common::versions_manifest::CrateVersion; + use crate::Args; + + use super::*; + #[test] fn test_arg_parsing() { assert!(Args::try_parse_from(["./canary-runner", "build-bundle"]).is_err()); diff --git a/tools/ci-cdk/canary-runner/src/run.rs b/tools/ci-cdk/canary-runner/src/run.rs index 19d6048720d..c5bafc6c165 100644 --- a/tools/ci-cdk/canary-runner/src/run.rs +++ b/tools/ci-cdk/canary-runner/src/run.rs @@ -14,11 +14,12 @@ // CAUTION: This subcommand will `git reset --hard` in some cases. Don't ever run // it against a smithy-rs repo that you're actively working in. -use crate::build_bundle::BuildBundleArgs; +use std::path::PathBuf; +use std::str::FromStr; +use std::time::{Duration, SystemTime}; +use std::{env, path::Path}; + use anyhow::{bail, Context, Result}; -use aws_sdk_cloudwatch as cloudwatch; -use aws_sdk_lambda as lambda; -use aws_sdk_s3 as s3; use clap::Parser; use cloudwatch::model::StandardUnit; use s3::types::ByteStream; @@ -26,12 +27,13 @@ use serde::Deserialize; use smithy_rs_tool_common::git::{find_git_repository_root, Git, GitCLI}; use smithy_rs_tool_common::macros::here; use smithy_rs_tool_common::release_tag::ReleaseTag; -use std::path::PathBuf; -use std::str::FromStr; -use std::time::{Duration, SystemTime}; -use std::{env, path::Path}; use tracing::info; +use crate::build_bundle::BuildBundleArgs; + +use aws_sdk_cloudwatch as cloudwatch; +use aws_sdk_lambda as lambda; +use aws_sdk_s3 as s3; lazy_static::lazy_static! { // Occasionally, a breaking change introduced in smithy-rs will cause the canary to fail // for older versions of the SDK since the canary is in the smithy-rs repository and will @@ -45,6 +47,9 @@ lazy_static::lazy_static! { // Versions <= 0.6.0 no longer compile against the canary after this commit in smithy-rs // due to the breaking change in https://github.com/awslabs/smithy-rs/pull/1085 (ReleaseTag::from_str("v0.6.0").unwrap(), "d48c234796a16d518ca9e1dda5c7a1da4904318c"), + // Versions <= release-2022-10-26 no longer compile against the canary after this commit in smithy-rs + // due to the s3 canary update in https://github.com/awslabs/smithy-rs/pull/1974 + (ReleaseTag::from_str("release-2022-10-26").unwrap(), "3e24477ae7a0a2b3853962a064bc8333a016af54") ]; pinned.sort(); pinned