Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade Smithy to 1.16.1 #1053

Merged
merged 11 commits into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,9 @@ message = "Add support for SSO credentials"
references = ["smithy-rs#1051", "aws-sdk-rust#4"]
meta = { "breaking" = false, "tada" = true, "bug" = false }
author = "rcoh"

[[smithy-rs]]
message = "Upgraded Smithy to 1.16.1"
references = ["smithy-rs#1053"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "jdisanti"
4 changes: 0 additions & 4 deletions aws/sdk-codegen-test/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ plugins {

val smithyVersion: String by project


dependencies {
implementation(project(":aws:sdk-codegen"))
implementation("software.amazon.smithy:smithy-aws-protocol-tests:$smithyVersion")
Expand Down Expand Up @@ -55,15 +54,13 @@ fun generateSmithyBuild(tests: List<CodegenTest>): String {
"""
}


task("generateSmithyBuild") {
description = "generate smithy-build.json"
doFirst {
projectDir.resolve("smithy-build.json").writeText(generateSmithyBuild(CodegenTests))
}
}


fun generateCargoWorkspace(tests: List<CodegenTest>): String {
return """
[workspace]
Expand All @@ -82,7 +79,6 @@ task("generateCargoWorkspace") {
tasks["smithyBuildJar"].dependsOn("generateSmithyBuild")
tasks["assemble"].finalizedBy("generateCargoWorkspace")


tasks.register<Exec>("cargoCheck") {
workingDir("build/smithyprojections/sdk-codegen-test/")
// disallow warnings
Expand Down
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ dependencies {
}

val lintPaths = listOf(
"codegen/src/**/*.kt"
"codegen/src/**/*.kt"
)

tasks.register<JavaExec>("ktlint") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolContentTypes
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.JsonParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.restJsonFieldName
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator

Expand Down Expand Up @@ -75,11 +76,11 @@ class ServerRestJson(private val codegenContext: CodegenContext) : Protocol {
override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
return JsonParserGenerator(codegenContext, httpBindingResolver)
return JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)
}

override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator {
return JsonSerializerGenerator(codegenContext, httpBindingResolver)
return JsonSerializerGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)
}

// NOTE: this method is only needed for the little part of client-codegen we use in tests.
Expand Down
42 changes: 42 additions & 0 deletions codegen-test/model/rest-json-extras.smithy
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,46 @@ use aws.api#service
use smithy.test#httpRequestTests
use smithy.test#httpResponseTests

// TODO(https://github.com/awslabs/smithy/pull/1049): Remove this once the test case in Smithy is fixed
apply InputAndOutputWithHeaders @httpResponseTests([
{
id: "FIXED_RestJsonInputAndOutputWithQuotedStringHeaders",
documentation: "Tests responses with string list header bindings that require quoting",
protocol: restJson1,
code: 200,
headers: {
"X-StringList": "\"b,c\", \"\\\"def\\\"\", a"
},
params: {
headerStringList: ["b,c", "\"def\"", "a"]
}
}
])

// TODO(https://github.com/awslabs/smithy/pull/1042): Remove this once the test case in Smithy is fixed
apply PostPlayerAction @httpRequestTests([
{
id: "FIXED_RestJsonInputUnionWithUnitMember",
documentation: "Unit types in unions are serialized like normal structures in requests.",
protocol: restJson1,
method: "POST",
"uri": "/PostPlayerInput",
body: """
{
"action": {
"quit": {}
}
}""",
bodyMediaType: "application/json",
headers: {"Content-Type": "application/json"},
params: {
action: {
quit: {}
}
}
}
])

apply QueryPrecedence @httpRequestTests([
{
id: "UrlParamsKeyEncoding",
Expand Down Expand Up @@ -64,6 +104,8 @@ service RestJsonExtras {
NullInNonSparse,
CaseInsensitiveErrorOperation,
EmptyStructWithContentOnWireOp,
// TODO(https://github.com/awslabs/smithy/pull/1042): Remove this once the test case in Smithy is fixed
PostPlayerAction
],
errors: [ExtraError]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class RequestBindingGenerator(
) {
private val index = HttpBindingIndex.of(model)
private val Encoder = CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder")
private val headerUtil = CargoDependency.SmithyHttp(runtimeConfig).asType().member("header")

private val codegenScope = arrayOf(
"BuildError" to runtimeConfig.operationBuildError(),
Expand Down Expand Up @@ -175,6 +176,7 @@ class RequestBindingGenerator(
else -> UNREACHABLE("unexpected member for prefix headers: $memberType")
}
ifSet(memberType, memberSymbol, "&_input.$memberName") { field ->
val listHeader = memberType is CollectionShape
rustTemplate(
"""
for (k, v) in $field {
Expand All @@ -183,8 +185,8 @@ class RequestBindingGenerator(
#{build_error}::InvalidField { field: ${memberName.dq()}, details: format!("`{}` cannot be used as a header name: {}", k, err)}
})?;
use std::convert::TryFrom;
let header_value = ${headerFmtFun(this, target, memberShape, "v")};
let header_value = http::header::HeaderValue::try_from(header_value).map_err(|err| {
let header_value = ${headerFmtFun(this, target, memberShape, "v", listHeader)};
let header_value = http::header::HeaderValue::try_from(&*header_value).map_err(|err| {
#{build_error}::InvalidField {
field: ${memberName.dq()},
details: format!("`{}` cannot be used as a header value: {}", ${
Expand All @@ -210,12 +212,13 @@ class RequestBindingGenerator(
val memberSymbol = symbolProvider.toSymbol(memberShape)
val memberName = symbolProvider.toMemberName(memberShape)
ifSet(memberType, memberSymbol, "&_input.$memberName") { field ->
val isListHeader = memberType is CollectionShape
listForEach(memberType, field) { innerField, targetId ->
val innerMemberType = model.expectShape(targetId)
if (innerMemberType.isPrimitive()) {
rust("let mut encoder = #T::from(${autoDeref(innerField)});", Encoder)
}
val formatted = headerFmtFun(this, innerMemberType, memberShape, innerField)
val formatted = headerFmtFun(this, innerMemberType, memberShape, innerField, isListHeader)
val safeName = safeName("formatted")
write("let $safeName = $formatted;")
rustBlock("if !$safeName.is_empty()") {
Expand Down Expand Up @@ -244,21 +247,30 @@ class RequestBindingGenerator(
/**
* Format [member] in the when used as an HTTP header
*/
private fun headerFmtFun(writer: RustWriter, target: Shape, member: MemberShape, targetName: String): String {
private fun headerFmtFun(writer: RustWriter, target: Shape, member: MemberShape, targetName: String, isListHeader: Boolean): String {
fun quoteValue(value: String): String {
// Timestamp shapes are not quoted in header lists
return if (isListHeader && !target.isTimestampShape) {
val quoteFn = writer.format(headerUtil.member("quote_header_value"))
"$quoteFn($value)"
} else {
value
}
}
return when {
target.isStringShape -> {
if (target.hasTrait<MediaTypeTrait>()) {
val func = writer.format(RuntimeType.Base64Encode(runtimeConfig))
"$func(&$targetName)"
} else {
"AsRef::<str>::as_ref($targetName)"
quoteValue("AsRef::<str>::as_ref($targetName)")
}
}
target.isTimestampShape -> {
val timestampFormat =
index.determineTimestampFormat(member, HttpBinding.Location.HEADER, defaultTimestampFormat)
val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
"$targetName.fmt(${writer.format(timestampFormatType)})?"
quoteValue("$targetName.fmt(${writer.format(timestampFormatType)})?")
}
target.isListShape || target.isMemberShape -> {
throw IllegalArgumentException("lists should be handled at a higher level")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,9 @@ class ProtocolTestGenerator(
rust(
"""
assert_eq!(
parsed.$memberName.collect().await.unwrap().into_bytes(),
expected_output.$memberName.collect().await.unwrap().into_bytes()
);
parsed.$memberName.collect().await.unwrap().into_bytes(),
expected_output.$memberName.collect().await.unwrap().into_bytes()
);
"""
)
} else {
Expand Down Expand Up @@ -367,7 +367,7 @@ class ProtocolTestGenerator(
}
val variableName = "expected_headers"
rustWriter.withBlock("let $variableName = [", "];") {
write(
writeWithNoFormatting(
headers.entries.joinToString(",") {
"(${it.key.dq()}, ${it.value.dq()})"
}
Expand Down Expand Up @@ -450,7 +450,13 @@ class ProtocolTestGenerator(
private val RestXml = "aws.protocoltests.restxml#RestXml"
private val AwsQuery = "aws.protocoltests.query#AwsQuery"
private val Ec2Query = "aws.protocoltests.ec2#AwsEc2"
private val ExpectFail = setOf<FailingTest>()
private val ExpectFail = setOf<FailingTest>(
// TODO(https://github.com/awslabs/smithy/pull/1049): Remove this once the test case in Smithy is fixed
FailingTest(RestJson, "RestJsonInputAndOutputWithQuotedStringHeaders", Action.Response),
// TODO(https://github.com/awslabs/smithy/pull/1042): Remove this once the test case in Smithy is fixed
FailingTest(RestJson, "RestJsonInputUnionWithUnitMember", Action.Request),
FailingTest("${RestJson}Extras", "RestJsonInputUnionWithUnitMember", Action.Request),
)
private val RunOnly: Set<String>? = null

// These tests are not even attempted to be generated, either because they will not compile
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package software.amazon.smithy.rust.codegen.smithy.protocols

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.pattern.UriPattern
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ToShapeId
import software.amazon.smithy.model.traits.HttpTrait
Expand All @@ -25,7 +26,6 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredData
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.orNull

sealed class AwsJsonVersion {
abstract val value: String
Expand Down Expand Up @@ -72,19 +72,18 @@ class AwsJsonHttpBindingResolver(
.uri(UriPattern.parse("/"))
.build()

private fun bindings(shape: ToShapeId?) =
shape?.let { model.expectShape(it.toShapeId()) }?.members()
?.map { HttpBindingDescriptor(it, HttpLocation.DOCUMENT, "document") }
?.toList()
?: emptyList()
private fun bindings(shape: ToShapeId) =
shape.let { model.expectShape(it.toShapeId()) }.members()
.map { HttpBindingDescriptor(it, HttpLocation.DOCUMENT, "document") }
.toList()

override fun httpTrait(operationShape: OperationShape): HttpTrait = httpTrait

override fun requestBindings(operationShape: OperationShape): List<HttpBindingDescriptor> =
bindings(operationShape.input.orNull())
bindings(operationShape.inputShape)

override fun responseBindings(operationShape: OperationShape): List<HttpBindingDescriptor> =
bindings(operationShape.output.orNull())
bindings(operationShape.outputShape)

override fun errorResponseBindings(errorShape: ToShapeId): List<HttpBindingDescriptor> =
bindings(errorShape)
Expand All @@ -103,7 +102,7 @@ class AwsJsonSerializerGenerator(
private val codegenContext: CodegenContext,
httpBindingResolver: HttpBindingResolver,
private val jsonSerializerGenerator: JsonSerializerGenerator =
JsonSerializerGenerator(codegenContext, httpBindingResolver)
JsonSerializerGenerator(codegenContext, httpBindingResolver, ::awsJsonFieldName)
) : StructuredDataSerializerGenerator by jsonSerializerGenerator {
private val runtimeConfig = codegenContext.runtimeConfig
private val codegenScope = arrayOf(
Expand Down Expand Up @@ -153,7 +152,7 @@ class AwsJson(
listOf("x-amz-target" to "${codegenContext.serviceShape.id.name}.${operationShape.id.name}")

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator =
JsonParserGenerator(codegenContext, httpBindingResolver)
JsonParserGenerator(codegenContext, httpBindingResolver, ::awsJsonFieldName)

override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
AwsJsonSerializerGenerator(codegenContext, httpBindingResolver)
Expand Down Expand Up @@ -183,3 +182,7 @@ class AwsJson(
)
}
}

private fun awsJsonFieldName(member: MemberShape): String {
return member.memberName
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package software.amazon.smithy.rust.codegen.smithy.protocols

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.JsonNameTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustModule
Expand All @@ -19,6 +21,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.JsonParserGene
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.util.getTrait

class RestJsonFactory : ProtocolGeneratorFactory<HttpBoundProtocolGenerator> {
override fun protocol(codegenContext: CodegenContext): Protocol = RestJson(codegenContext)
Expand Down Expand Up @@ -61,13 +64,11 @@ class RestJson(private val codegenContext: CodegenContext) : Protocol {

override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
return JsonParserGenerator(codegenContext, httpBindingResolver)
}
override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator =
JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)

override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator {
return JsonSerializerGenerator(codegenContext, httpBindingResolver)
}
override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
JsonSerializerGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)

override fun parseHttpGenericError(operationShape: OperationShape): RuntimeType =
RuntimeType.forInlineFun("parse_http_generic_error", jsonDeserModule) { writer ->
Expand All @@ -94,3 +95,7 @@ class RestJson(private val codegenContext: CodegenContext) : Protocol {
)
}
}

fun restJsonFieldName(member: MemberShape): String {
return member.getTrait<JsonNameTrait>()?.value ?: member.memberName
}
Loading