Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade Smithy to 1.16.1 #1053

Merged
merged 11 commits into from
Jan 12, 2022
Merged
6 changes: 6 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,9 @@
# references = ["smithy-rs#920"]
# meta = { "breaking" = false, "tada" = false, "bug" = false }
# author = "rcoh"

[[smithy-rs]]
message = "Upgraded Smithy to 1.16.1"
references = ["smithy-rs#1053"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "jdisanti"
4 changes: 0 additions & 4 deletions aws/sdk-codegen-test/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ plugins {

val smithyVersion: String by project


dependencies {
implementation(project(":aws:sdk-codegen"))
implementation("software.amazon.smithy:smithy-aws-protocol-tests:$smithyVersion")
Expand Down Expand Up @@ -55,15 +54,13 @@ fun generateSmithyBuild(tests: List<CodegenTest>): String {
"""
}


task("generateSmithyBuild") {
description = "generate smithy-build.json"
doFirst {
projectDir.resolve("smithy-build.json").writeText(generateSmithyBuild(CodegenTests))
}
}


fun generateCargoWorkspace(tests: List<CodegenTest>): String {
return """
[workspace]
Expand All @@ -82,7 +79,6 @@ task("generateCargoWorkspace") {
tasks["smithyBuildJar"].dependsOn("generateSmithyBuild")
tasks["assemble"].finalizedBy("generateCargoWorkspace")


tasks.register<Exec>("cargoCheck") {
workingDir("build/smithyprojections/sdk-codegen-test/")
// disallow warnings
Expand Down
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ dependencies {
}

val lintPaths = listOf(
"codegen/src/**/*.kt"
"codegen/src/**/*.kt"
)

tasks.register<JavaExec>("ktlint") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolContentTypes
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.JsonParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.restJsonFieldName
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator

Expand Down Expand Up @@ -75,11 +76,11 @@ class ServerRestJson(private val codegenContext: CodegenContext) : Protocol {
override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
return JsonParserGenerator(codegenContext, httpBindingResolver)
return JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)
}

override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator {
return JsonSerializerGenerator(codegenContext, httpBindingResolver)
return JsonSerializerGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)
}

// NOTE: this method is only needed for the little part of client-codegen we use in tests.
Expand Down
26 changes: 26 additions & 0 deletions codegen-test/model/rest-json-extras.smithy
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,30 @@ use aws.api#service
use smithy.test#httpRequestTests
use smithy.test#httpResponseTests

// TODO(https://github.com/awslabs/smithy/pull/1042): Remove this once the test case in Smithy is fixed
apply PostPlayerAction @httpRequestTests([
{
id: "FIXED_RestJsonInputUnionWithUnitMember",
documentation: "Unit types in unions are serialized like normal structures in requests.",
protocol: restJson1,
method: "POST",
"uri": "/PostPlayerInput",
body: """
{
"action": {
"quit": {}
}
}""",
bodyMediaType: "application/json",
headers: {"Content-Type": "application/json"},
params: {
action: {
quit: {}
}
}
}
])

apply QueryPrecedence @httpRequestTests([
{
id: "UrlParamsKeyEncoding",
Expand Down Expand Up @@ -64,6 +88,8 @@ service RestJsonExtras {
NullInNonSparse,
CaseInsensitiveErrorOperation,
EmptyStructWithContentOnWireOp,
// TODO(https://github.com/awslabs/smithy/pull/1042): Remove this once the test case in Smithy is fixed
PostPlayerAction
],
errors: [ExtraError]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class RequestBindingGenerator(
) {
private val index = HttpBindingIndex.of(model)
private val Encoder = CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder")
private val headerUtil = CargoDependency.SmithyHttp(runtimeConfig).asType().member("header")

private val codegenScope = arrayOf(
"BuildError" to runtimeConfig.operationBuildError(),
Expand Down Expand Up @@ -175,6 +176,7 @@ class RequestBindingGenerator(
else -> UNREACHABLE("unexpected member for prefix headers: $memberType")
}
ifSet(memberType, memberSymbol, "&_input.$memberName") { field ->
val listHeader = memberType is CollectionShape
rustTemplate(
"""
for (k, v) in $field {
Expand All @@ -183,8 +185,8 @@ class RequestBindingGenerator(
#{build_error}::InvalidField { field: ${memberName.dq()}, details: format!("`{}` cannot be used as a header name: {}", k, err)}
})?;
use std::convert::TryFrom;
let header_value = ${headerFmtFun(this, target, memberShape, "v")};
let header_value = http::header::HeaderValue::try_from(header_value).map_err(|err| {
let header_value = ${headerFmtFun(this, target, memberShape, "v", listHeader)};
let header_value = http::header::HeaderValue::try_from(&*header_value).map_err(|err| {
#{build_error}::InvalidField {
field: ${memberName.dq()},
details: format!("`{}` cannot be used as a header value: {}", ${
Expand All @@ -210,12 +212,13 @@ class RequestBindingGenerator(
val memberSymbol = symbolProvider.toSymbol(memberShape)
val memberName = symbolProvider.toMemberName(memberShape)
ifSet(memberType, memberSymbol, "&_input.$memberName") { field ->
val listHeader = memberType is CollectionShape
jdisanti marked this conversation as resolved.
Show resolved Hide resolved
listForEach(memberType, field) { innerField, targetId ->
val innerMemberType = model.expectShape(targetId)
if (innerMemberType.isPrimitive()) {
rust("let mut encoder = #T::from(${autoDeref(innerField)});", Encoder)
}
val formatted = headerFmtFun(this, innerMemberType, memberShape, innerField)
val formatted = headerFmtFun(this, innerMemberType, memberShape, innerField, listHeader)
val safeName = safeName("formatted")
write("let $safeName = $formatted;")
rustBlock("if !$safeName.is_empty()") {
Expand Down Expand Up @@ -244,21 +247,30 @@ class RequestBindingGenerator(
/**
* Format [member] in the when used as an HTTP header
*/
private fun headerFmtFun(writer: RustWriter, target: Shape, member: MemberShape, targetName: String): String {
private fun headerFmtFun(writer: RustWriter, target: Shape, member: MemberShape, targetName: String, listHeader: Boolean): String {
fun quoteValue(value: String): String {
// Timestamp shapes are not quoted in header lists
return if (listHeader && !target.isTimestampShape) {
val quoteFn = writer.format(headerUtil.member("quote_header_value"))
"$quoteFn($value)"
} else {
value
}
}
return when {
target.isStringShape -> {
if (target.hasTrait<MediaTypeTrait>()) {
val func = writer.format(RuntimeType.Base64Encode(runtimeConfig))
"$func(&$targetName)"
} else {
"AsRef::<str>::as_ref($targetName)"
quoteValue("AsRef::<str>::as_ref($targetName)")
}
}
target.isTimestampShape -> {
val timestampFormat =
index.determineTimestampFormat(member, HttpBinding.Location.HEADER, defaultTimestampFormat)
val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
"$targetName.fmt(${writer.format(timestampFormatType)})?"
quoteValue("$targetName.fmt(${writer.format(timestampFormatType)})?")
}
target.isListShape || target.isMemberShape -> {
throw IllegalArgumentException("lists should be handled at a higher level")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,11 @@ class ProtocolTestGenerator(
private val RestXml = "aws.protocoltests.restxml#RestXml"
private val AwsQuery = "aws.protocoltests.query#AwsQuery"
private val Ec2Query = "aws.protocoltests.ec2#AwsEc2"
private val ExpectFail = setOf<FailingTest>()
private val ExpectFail = setOf<FailingTest>(
// TODO(https://github.com/awslabs/smithy/pull/1042): Remove this once the test case in Smithy is fixed
FailingTest(RestJson, "RestJsonInputUnionWithUnitMember", Action.Request),
FailingTest("${RestJson}Extras", "RestJsonInputUnionWithUnitMember", Action.Request),
)
private val RunOnly: Set<String>? = null

// These tests are not even attempted to be generated, either because they will not compile
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package software.amazon.smithy.rust.codegen.smithy.protocols

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.pattern.UriPattern
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ToShapeId
import software.amazon.smithy.model.traits.HttpTrait
Expand All @@ -25,7 +26,6 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredData
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.orNull

sealed class AwsJsonVersion {
abstract val value: String
Expand Down Expand Up @@ -72,19 +72,18 @@ class AwsJsonHttpBindingResolver(
.uri(UriPattern.parse("/"))
.build()

private fun bindings(shape: ToShapeId?) =
shape?.let { model.expectShape(it.toShapeId()) }?.members()
?.map { HttpBindingDescriptor(it, HttpLocation.DOCUMENT, "document") }
?.toList()
?: emptyList()
private fun bindings(shape: ToShapeId) =
shape.let { model.expectShape(it.toShapeId()) }.members()
.map { HttpBindingDescriptor(it, HttpLocation.DOCUMENT, "document") }
.toList()

override fun httpTrait(operationShape: OperationShape): HttpTrait = httpTrait

override fun requestBindings(operationShape: OperationShape): List<HttpBindingDescriptor> =
bindings(operationShape.input.orNull())
bindings(operationShape.inputShape)

override fun responseBindings(operationShape: OperationShape): List<HttpBindingDescriptor> =
bindings(operationShape.output.orNull())
bindings(operationShape.outputShape)

override fun errorResponseBindings(errorShape: ToShapeId): List<HttpBindingDescriptor> =
bindings(errorShape)
Expand All @@ -103,7 +102,7 @@ class AwsJsonSerializerGenerator(
private val codegenContext: CodegenContext,
httpBindingResolver: HttpBindingResolver,
private val jsonSerializerGenerator: JsonSerializerGenerator =
JsonSerializerGenerator(codegenContext, httpBindingResolver)
JsonSerializerGenerator(codegenContext, httpBindingResolver, ::awsJsonFieldName)
) : StructuredDataSerializerGenerator by jsonSerializerGenerator {
private val runtimeConfig = codegenContext.runtimeConfig
private val codegenScope = arrayOf(
Expand Down Expand Up @@ -153,7 +152,7 @@ class AwsJson(
listOf("x-amz-target" to "${codegenContext.serviceShape.id.name}.${operationShape.id.name}")

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator =
JsonParserGenerator(codegenContext, httpBindingResolver)
JsonParserGenerator(codegenContext, httpBindingResolver, ::awsJsonFieldName)

override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
AwsJsonSerializerGenerator(codegenContext, httpBindingResolver)
Expand Down Expand Up @@ -183,3 +182,7 @@ class AwsJson(
)
}
}

private fun awsJsonFieldName(member: MemberShape): String {
return member.memberName
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package software.amazon.smithy.rust.codegen.smithy.protocols

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.JsonNameTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustModule
Expand All @@ -19,6 +21,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.JsonParserGene
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.util.getTrait

class RestJsonFactory : ProtocolGeneratorFactory<HttpBoundProtocolGenerator> {
override fun protocol(codegenContext: CodegenContext): Protocol = RestJson(codegenContext)
Expand Down Expand Up @@ -61,13 +64,11 @@ class RestJson(private val codegenContext: CodegenContext) : Protocol {

override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
return JsonParserGenerator(codegenContext, httpBindingResolver)
}
override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator =
JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)

override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator {
return JsonSerializerGenerator(codegenContext, httpBindingResolver)
}
override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
JsonSerializerGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)

override fun parseHttpGenericError(operationShape: OperationShape): RuntimeType =
RuntimeType.forInlineFun("parse_http_generic_error", jsonDeserModule) { writer ->
Expand All @@ -94,3 +95,7 @@ class RestJson(private val codegenContext: CodegenContext) : Protocol {
)
}
}

fun restJsonFieldName(member: MemberShape): String {
return member.getTrait<JsonNameTrait>()?.value ?: member.memberName
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.JsonNameTrait
import software.amazon.smithy.model.traits.SparseTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.rustlang.Attribute
Expand Down Expand Up @@ -48,7 +47,6 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.protocols.deserializeFunctionName
import software.amazon.smithy.rust.codegen.util.PANIC
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape
Expand All @@ -57,6 +55,8 @@ import software.amazon.smithy.utils.StringUtils
class JsonParserGenerator(
codegenContext: CodegenContext,
private val httpBindingResolver: HttpBindingResolver,
/** Function that maps a MemberShape into a JSON field name */
private val jsonName: (MemberShape) -> String,
) : StructuredDataParserGenerator {
private val model = codegenContext.model
private val symbolProvider = codegenContext.symbolProvider
Expand Down Expand Up @@ -220,7 +220,7 @@ class JsonParserGenerator(
objectKeyLoop(hasMembers = members.isNotEmpty()) {
rustBlock("match key.to_unescaped()?.as_ref()") {
for (member in members) {
rustBlock("${member.wireName().dq()} =>") {
rustBlock("${jsonName(member).dq()} =>") {
withBlock("builder = builder.${member.setterName()}(", ");") {
deserializeMember(member)
}
Expand Down Expand Up @@ -430,7 +430,7 @@ class JsonParserGenerator(
withBlock("variant = match key.to_unescaped()?.as_ref() {", "};") {
for (member in shape.members()) {
val variantName = symbolProvider.toMemberName(member)
rustBlock("${member.wireName().dq()} =>") {
rustBlock("${jsonName(member).dq()} =>") {
withBlock("Some(#T::$variantName(", "))", symbol) {
deserializeMember(member)
unwrapOrDefaultOrError(member)
Expand Down Expand Up @@ -524,6 +524,4 @@ class JsonParserGenerator(
}
}
}

private fun MemberShape.wireName(): String = getTrait<JsonNameTrait>()?.value ?: memberName
}
Loading