Skip to content

Commit

Permalink
Add RPCv2 support
Browse files Browse the repository at this point in the history
  • Loading branch information
david-perez committed Jun 4, 2024
1 parent 9c1ae5a commit b9016ae
Show file tree
Hide file tree
Showing 68 changed files with 4,022 additions and 120 deletions.
3 changes: 1 addition & 2 deletions aws/sdk-adhoc-test/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ java {
}

val smithyVersion: String by project
val defaultRustDocFlags: String by project
val properties = PropertyRetriever(rootProject, project)

val pluginName = "rust-client-codegen"
Expand Down Expand Up @@ -78,7 +77,7 @@ tasks["smithyBuild"].dependsOn("generateSmithyBuild")
tasks["assemble"].finalizedBy("generateCargoWorkspace")

project.registerModifyMtimeTask()
project.registerCargoCommandsTasks(layout.buildDirectory.dir(workingDirUnderBuildDir).get().asFile, defaultRustDocFlags)
project.registerCargoCommandsTasks(layout.buildDirectory.dir(workingDirUnderBuildDir).get().asFile)

tasks["test"].finalizedBy(cargoCommands(properties).map { it.toString })

Expand Down
3 changes: 1 addition & 2 deletions aws/sdk/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ configure<software.amazon.smithy.gradle.SmithyExtension> {
}

val smithyVersion: String by project
val defaultRustDocFlags: String by project
val properties = PropertyRetriever(rootProject, project)

val crateHasherToolPath = rootProject.projectDir.resolve("tools/ci-build/crate-hasher")
Expand Down Expand Up @@ -442,7 +441,7 @@ tasks["assemble"].apply {
outputs.upToDateWhen { false }
}

project.registerCargoCommandsTasks(outputDir.asFile, defaultRustDocFlags)
project.registerCargoCommandsTasks(outputDir.asFile)
project.registerGenerateCargoConfigTomlTask(outputDir.asFile)

//The task name "test" is already registered by one of our plugins
Expand Down
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ allprojects {
val allowLocalDeps: String by project
repositories {
if (allowLocalDeps.toBoolean()) {
mavenLocal()
mavenLocal()
}
mavenCentral()
google()
Expand Down
59 changes: 48 additions & 11 deletions buildSrc/src/main/kotlin/CodegenTestCommon.kt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,49 @@ fun generateImports(imports: List<String>): String =
"\"imports\": [${imports.map { "\"$it\"" }.joinToString(", ")}],"
}

fun toRustCrateName(input: String): String {
val rustKeywords = setOf(
// Strict Keywords.
"as", "break", "const", "continue", "crate", "else", "enum", "extern",
"false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod",
"move", "mut", "pub", "ref", "return", "Self", "self", "static", "struct",
"super", "trait", "true", "type", "unsafe", "use", "where", "while",

// Weak Keywords.
"dyn", "async", "await", "try",

// Reserved for Future Use.
"abstract", "become", "box", "do", "final", "macro", "override", "priv",
"typeof", "unsized", "virtual", "yield",

// Primitive Types.
"bool", "char", "i8", "i16", "i32", "i64", "i128", "isize",
"u8", "u16", "u32", "u64", "u128", "usize", "f32", "f64", "str",

// Additional significant identifiers.
"proc_macro"
)

// Then within your function, you could include a check against this set
if (input.isBlank()) {
throw IllegalArgumentException("Rust crate name cannot be empty")
}
val lowerCased = input.lowercase()
// Replace any sequence of characters that are not lowercase letters, numbers, or underscores with a single underscore.
val sanitized = lowerCased.replace(Regex("[^a-z0-9_]+"), "_")
// Trim leading or trailing underscores.
val trimmed = sanitized.trim('_')
// Check if the resulting string is empty, purely numeric, or a reserved name
val finalName = when {
trimmed.isEmpty() -> throw IllegalArgumentException("Rust crate name after sanitizing cannot be empty.")
trimmed.matches(Regex("\\d+")) -> "n$trimmed" // Prepend 'n' if the name is purely numeric.
trimmed in rustKeywords -> "${trimmed}_" // Append an underscore if the name is reserved.
else -> trimmed
}
return finalName
}


private fun generateSmithyBuild(
projectDir: String,
pluginName: String,
Expand All @@ -48,7 +91,7 @@ private fun generateSmithyBuild(
${it.extraCodegenConfig ?: ""}
},
"service": "${it.service}",
"module": "${it.module}",
"module": "${toRustCrateName(it.module)}",
"moduleVersion": "0.0.1",
"moduleDescription": "test",
"moduleAuthors": ["[email protected]"]
Expand Down Expand Up @@ -205,13 +248,15 @@ fun Project.registerGenerateCargoWorkspaceTask(
fun Project.registerGenerateCargoConfigTomlTask(outputDir: File) {
this.tasks.register("generateCargoConfigToml") {
description = "generate `.cargo/config.toml`"
// TODO(https://github.com/smithy-lang/smithy-rs/issues/1068): Once doc normalization
// is completed, warnings can be prohibited in rustdoc by setting `rustdocflags` to `-D warnings`.
doFirst {
outputDir.resolve(".cargo").mkdirs()
outputDir.resolve(".cargo/config.toml")
.writeText(
"""
[build]
rustflags = ["--deny", "warnings"]
rustflags = ["--deny", "warnings", "--cfg", "aws_sdk_unstable"]
""".trimIndent(),
)
}
Expand Down Expand Up @@ -255,10 +300,7 @@ fun Project.registerModifyMtimeTask() {
}
}

fun Project.registerCargoCommandsTasks(
outputDir: File,
defaultRustDocFlags: String,
) {
fun Project.registerCargoCommandsTasks(outputDir: File) {
val dependentTasks =
listOfNotNull(
"assemble",
Expand All @@ -269,29 +311,24 @@ fun Project.registerCargoCommandsTasks(
this.tasks.register<Exec>(Cargo.CHECK.toString) {
dependsOn(dependentTasks)
workingDir(outputDir)
environment("RUSTFLAGS", "--cfg aws_sdk_unstable")
commandLine("cargo", "check", "--lib", "--tests", "--benches", "--all-features")
}

this.tasks.register<Exec>(Cargo.TEST.toString) {
dependsOn(dependentTasks)
workingDir(outputDir)
environment("RUSTFLAGS", "--cfg aws_sdk_unstable")
commandLine("cargo", "test", "--all-features", "--no-fail-fast")
}

this.tasks.register<Exec>(Cargo.DOCS.toString) {
dependsOn(dependentTasks)
workingDir(outputDir)
environment("RUSTDOCFLAGS", defaultRustDocFlags)
environment("RUSTFLAGS", "--cfg aws_sdk_unstable")
commandLine("cargo", "doc", "--no-deps", "--document-private-items")
}

this.tasks.register<Exec>(Cargo.CLIPPY.toString) {
dependsOn(dependentTasks)
workingDir(outputDir)
environment("RUSTFLAGS", "--cfg aws_sdk_unstable")
commandLine("cargo", "clippy")
}
}
8 changes: 6 additions & 2 deletions codegen-client-test/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ plugins {
}

val smithyVersion: String by project
val defaultRustDocFlags: String by project
val properties = PropertyRetriever(rootProject, project)
fun getSmithyRuntimeMode(): String = properties.get("smithy.runtime.mode") ?: "orchestrator"

Expand Down Expand Up @@ -112,6 +111,11 @@ val allCodegenTests = listOf(
"pokemon-service-awsjson-client",
dependsOn = listOf("pokemon-awsjson.smithy", "pokemon-common.smithy"),
),
ClientTest(
"com.amazonaws.simple#RpcV2Service",
"rpcv2-pokemon-client",
dependsOn = listOf("rpcv2.smithy")
),
ClientTest("aws.protocoltests.misc#QueryCompatService", "query-compat-test", dependsOn = listOf("aws-json-query-compat.smithy")),
).map(ClientTest::toCodegenTest)

Expand All @@ -125,7 +129,7 @@ tasks["smithyBuild"].dependsOn("generateSmithyBuild")
tasks["assemble"].finalizedBy("generateCargoWorkspace")

project.registerModifyMtimeTask()
project.registerCargoCommandsTasks(layout.buildDirectory.dir(workingDirUnderBuildDir).get().asFile, defaultRustDocFlags)
project.registerCargoCommandsTasks(layout.buildDirectory.dir(workingDirUnderBuildDir).get().asFile)

tasks["test"].finalizedBy(cargoCommands(properties).map { it.toString })

Expand Down
3 changes: 2 additions & 1 deletion codegen-client/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ dependencies {
implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-waiters:$smithyVersion")
implementation("software.amazon.smithy:smithy-rules-engine:$smithyVersion")
implementation("software.amazon.smithy:smithy-protocol-traits:$smithyVersion")

// `smithy.framework#ValidationException` is defined here, which is used in event stream
// marshalling/unmarshalling tests.
// marshalling/unmarshalling tests.
testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
Expand Down Expand Up @@ -173,13 +175,19 @@ class DefaultProtocolTestGenerator(
}
testModuleWriter.write("Test ID: ${testCase.id}")
testModuleWriter.newlinePrefix = ""

Attribute.TokioTest.render(testModuleWriter)
val action =
when (testCase) {
is HttpResponseTestCase -> Action.Response
is HttpRequestTestCase -> Action.Request
else -> throw CodegenException("unknown test case type")
}
Attribute.TracedTest.render(testModuleWriter)
// The `#[traced_test]` macro desugars to using `tracing`, so we need to depend on the latter explicitly in
// case the code rendered by the test does not make use of `tracing` at all.
val tracingDevDependency = testDependenciesOnly { addDependency(CargoDependency.Tracing.toDevDependency()) }
testModuleWriter.rustTemplate("#{TracingDevDependency:W}", "TracingDevDependency" to tracingDevDependency)

val action = when (testCase) {
is HttpResponseTestCase -> Action.Response
is HttpRequestTestCase -> Action.Request
else -> throw CodegenException("unknown test case type")
}
if (expectFail(testCase)) {
testModuleWriter.writeWithNoFormatting("#[should_panic]")
}
Expand Down Expand Up @@ -415,8 +423,8 @@ class DefaultProtocolTestGenerator(
if (body == "") {
rustWriter.rustTemplate(
"""
// No body
#{AssertEq}(::std::str::from_utf8(body).unwrap(), "");
// No body.
#{AssertEq}(&body, &bytes::Bytes::new());
""",
*codegenScope,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGenerator
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
Expand All @@ -28,20 +29,21 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolLoader
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2
import software.amazon.smithy.rust.codegen.core.util.hasTrait

class ClientProtocolLoader(supportedProtocols: ProtocolMap<OperationGenerator, ClientCodegenContext>) :
ProtocolLoader<OperationGenerator, ClientCodegenContext>(supportedProtocols) {
companion object {
val DefaultProtocols =
mapOf(
AwsJson1_0Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json10),
AwsJson1_1Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json11),
AwsQueryTrait.ID to ClientAwsQueryFactory(),
Ec2QueryTrait.ID to ClientEc2QueryFactory(),
RestJson1Trait.ID to ClientRestJsonFactory(),
RestXmlTrait.ID to ClientRestXmlFactory(),
)
val DefaultProtocols = mapOf(
AwsJson1_0Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json10),
AwsJson1_1Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json11),
AwsQueryTrait.ID to ClientAwsQueryFactory(),
Ec2QueryTrait.ID to ClientEc2QueryFactory(),
RestJson1Trait.ID to ClientRestJsonFactory(),
RestXmlTrait.ID to ClientRestXmlFactory(),
Rpcv2CborTrait.ID to ClientRpcV2CborFactory(),
)
val Default = ClientProtocolLoader(DefaultProtocols)
}
}
Expand Down Expand Up @@ -117,3 +119,12 @@ class ClientRestXmlFactory(

override fun support(): ProtocolSupport = CLIENT_PROTOCOL_SUPPORT
}

class ClientRpcV2CborFactory : ProtocolGeneratorFactory<OperationGenerator, ClientCodegenContext> {
override fun protocol(codegenContext: ClientCodegenContext): Protocol = RpcV2(codegenContext)

override fun buildProtocolGenerator(codegenContext: ClientCodegenContext): OperationGenerator =
OperationGenerator(codegenContext, protocol(codegenContext))

override fun support(): ProtocolSupport = CLIENT_PROTOCOL_SUPPORT
}
1 change: 1 addition & 0 deletions codegen-core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies {
implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-waiters:$smithyVersion")
implementation("software.amazon.smithy:smithy-protocol-traits:$smithyVersion")
}

fun gitCommitHash(): String {
Expand Down
Loading

0 comments on commit b9016ae

Please sign in to comment.