Skip to content

Commit

Permalink
Improve broken protocol test generation (#3726)
Browse files Browse the repository at this point in the history
We currently "hotfix" a broken protocol test in-memory, but there's no
mechanism that alerts us when the broken protocol test has been fixed
upstream when updating our Smithy version. This commit introduces such a
mechanism by generating both the original and the fixed test, with a
`#[should_panic]` attribute on the former, so that the test fails when
all its assertions succeed.

With this change, in general this approach of fixing tests in-memory
should now be used over adding the broken test to `expectFail` and
adding the fixed test to a `<protocol>-extras.smithy` Smithy model,
which is substantially more effort.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
  • Loading branch information
david-perez authored Jul 1, 2024
1 parent 24a011b commit 7299cdd
Show file tree
Hide file tree
Showing 3 changed files with 302 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.BrokenTest
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.FailingTest
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.AWS_JSON_10
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCaseKind
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.hasTrait
Expand Down Expand Up @@ -70,9 +70,9 @@ class ClientProtocolTestGenerator(
private val ExpectFail =
setOf<FailingTest>(
// Failing because we don't serialize default values if they match the default.
FailingTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultsValuesWhenMissingInResponse", TestCaseKind.Request),
FailingTest(AWS_JSON_10, "AwsJson10ClientUsesExplicitlyProvidedMemberValuesOverDefaults", TestCaseKind.Request),
FailingTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultValuesInInput", TestCaseKind.Request),
FailingTest.RequestTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultsValuesWhenMissingInResponse"),
FailingTest.RequestTest(AWS_JSON_10, "AwsJson10ClientUsesExplicitlyProvidedMemberValuesOverDefaults"),
FailingTest.RequestTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultValuesInInput"),
)
}

Expand All @@ -84,6 +84,8 @@ class ClientProtocolTestGenerator(
get() = emptySet()
override val disabledTests: Set<String>
get() = emptySet()
override val brokenTests: Set<BrokenTest>
get() = emptySet()

override val logger: Logger = Logger.getLogger(javaClass.name)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.orNull
Expand All @@ -51,9 +52,17 @@ abstract class ProtocolTestGenerator {
/**
* We expect these tests to fail due to shortcomings in our implementation.
* They will _fail_ if they pass, so we will discover and remove them if we fix them by accident.
**/
*/
abstract val expectFail: Set<FailingTest>

/**
* We expect these tests to fail because their definitions are broken.
* We map from a failing test to a "hotfix" function that can mutate the test in-memory and return a fixed version of it.
* The tests will _fail_ if they pass, so we will discover and remove the hotfix if we're updating to a newer
* version of Smithy where the test was fixed upstream.
*/
abstract val brokenTests: Set<BrokenTest>

/** Only generate these tests; useful to temporarily set and shorten development cycles */
abstract val runOnly: Set<String>

Expand All @@ -63,18 +72,23 @@ abstract class ProtocolTestGenerator {
*/
abstract val disabledTests: Set<String>

private val serviceShapeId: ShapeId
get() = codegenContext.serviceShape.id

/** The Rust module in which we should generate the protocol tests for [operationShape]. */
private fun protocolTestsModule(): RustModule.LeafModule {
val operationName = codegenContext.symbolProvider.toSymbol(operationShape).name
val testModuleName = "${operationName.toSnakeCase()}_test"
val additionalAttributes =
listOf(Attribute(allow("unreachable_code", "unused_variables")))
val additionalAttributes = listOf(Attribute(allow("unreachable_code", "unused_variables")))
return RustModule.inlineTests(testModuleName, additionalAttributes = additionalAttributes)
}

/** The entry point to render the protocol tests, invoked by the code generators. */
fun render(writer: RustWriter) {
val allTests = allMatchingTestCases().fixBroken()
val allTests =
allMatchingTestCases().flatMap {
fixBrokenTestCase(it)
}
if (allTests.isEmpty()) {
return
}
Expand All @@ -84,15 +98,65 @@ abstract class ProtocolTestGenerator {
}
}

/** Implementors should describe how to render the test cases. **/
abstract fun RustWriter.renderAllTestCases(allTests: List<TestCase>)

/**
* This function applies a "fix function" to each broken test before we synthesize it.
* Broken tests are those whose definitions in the `awslabs/smithy` repository are wrong.
* We try to contribute fixes upstream to pare down this function to the identity function.
* This function applies a "hotfix function" to a broken test case before we synthesize it.
* Broken tests are those whose definitions in the `smithy-lang/smithy` repository are wrong.
* We try to contribute fixes upstream to pare down the list of broken tests.
* If the test is broken, we synthesize it in two versions: the original broken test with a `#[should_panic]`
* attribute, so get alerted if the test now passes, and the fixed version, which should pass.
*/
open fun List<TestCase>.fixBroken(): List<TestCase> = this
private fun fixBrokenTestCase(it: TestCase): List<TestCase> =
if (!it.isBroken()) {
listOf(it)
} else {
assert(it.expectFail())

val brokenTest = it.findInBroken()!!
var fixed = brokenTest.fixIt(it)

val intro = "The hotfix function for broken test case ${it.kind} ${it.id}"
val moreInfo =
"""This test case was identified to be broken in at least these Smithy versions: [${brokenTest.inAtLeast.joinToString()}].
|We are tracking things here: [${brokenTest.trackedIn.joinToString()}].
""".trimMargin()

// Something must change...
if (it == fixed) {
PANIC(
"""$intro did not make any modifications. It is likely that the test case was
|fixed upstream, and you're now updating the Smithy version; in this case, remove the hotfix
|function, as the test is no longer broken.
|$moreInfo
""".trimMargin(),
)
}

// ... but the hotfix function is not allowed to change the test case kind...
if (it.kind != fixed.kind) {
PANIC(
"""$intro changed the test case kind. This is not allowed.
|$moreInfo
""".trimMargin(),
)
}

// ... nor its id.
if (it.id != fixed.id) {
PANIC(
"""$intro changed the test case id. This is not allowed.
|$moreInfo
""".trimMargin(),
)
}

// The latter is because we're going to generate the fixed version with an identifiable suffix.
fixed = fixed.suffixIdWith("_hotfixed")

listOf(it, fixed)
}

/** Implementors should describe how to render the test cases. **/
abstract fun RustWriter.renderAllTestCases(allTests: List<TestCase>)

/** Filter out test cases that are disabled or don't match the service protocol. */
private fun List<TestCase>.filterMatching(): List<TestCase> =
Expand All @@ -103,11 +167,25 @@ abstract class ProtocolTestGenerator {
this.filter { testCase -> runOnly.contains(testCase.id) }
}

/** Do we expect this [testCase] to fail? */
private fun expectFail(testCase: TestCase): Boolean =
expectFail.find {
it.id == testCase.id && it.kind == testCase.kind && it.service == codegenContext.serviceShape.id.toString()
} != null
private fun TestCase.toFailingTest(): FailingTest =
when (this) {
is TestCase.MalformedRequestTest -> FailingTest.MalformedRequestTest(serviceShapeId.toString(), this.id)
is TestCase.RequestTest -> FailingTest.RequestTest(serviceShapeId.toString(), this.id)
is TestCase.ResponseTest -> FailingTest.ResponseTest(serviceShapeId.toString(), this.id)
}

/** Do we expect this test case to fail? */
private fun TestCase.expectFail(): Boolean = this.isBroken() || expectFail.contains(this.toFailingTest())

/** Is this test case broken? */
private fun TestCase.isBroken(): Boolean = this.findInBroken() != null

private fun TestCase.findInBroken(): BrokenTest? =
brokenTests.find { brokenTest ->
(this is TestCase.RequestTest && brokenTest is BrokenTest.RequestTest && this.id == brokenTest.id) ||
(this is TestCase.ResponseTest && brokenTest is BrokenTest.ResponseTest && this.id == brokenTest.id) ||
(this is TestCase.MalformedRequestTest && brokenTest is BrokenTest.MalformedRequestTest && this.id == brokenTest.id)
}

fun requestTestCases(): List<TestCase> {
val requestTests =
Expand Down Expand Up @@ -160,6 +238,7 @@ abstract class ProtocolTestGenerator {
block: Writable,
) {
if (testCase.documentation != null) {
testModuleWriter.rust("")
testModuleWriter.docs(testCase.documentation!!, templating = false)
}
testModuleWriter.docs("Test ID: ${testCase.id}")
Expand All @@ -171,7 +250,7 @@ abstract class ProtocolTestGenerator {
Attribute.TokioTest.render(testModuleWriter)
Attribute.TracedTest.render(testModuleWriter)

if (expectFail(testCase)) {
if (testCase.expectFail()) {
shouldPanic().render(testModuleWriter)
}
val fnNameSuffix =
Expand Down Expand Up @@ -281,6 +360,51 @@ abstract class ProtocolTestGenerator {
}
}

sealed class BrokenTest(
open val serviceShapeId: String,
open val id: String,
/** A non-exhaustive set of Smithy versions where the test was found to be broken. */
open val inAtLeast: Set<String>,
/**
* GitHub URLs related to the test brokenness, like a GitHub issue in Smithy where we reported the test was broken,
* or a PR where we fixed it.
**/
open val trackedIn: Set<String>,
) {
data class RequestTest(
override val serviceShapeId: String,
override val id: String,
override val inAtLeast: Set<String>,
override val trackedIn: Set<String>,
val howToFixItFn: (TestCase.RequestTest) -> TestCase.RequestTest,
) : BrokenTest(serviceShapeId, id, inAtLeast, trackedIn)

data class ResponseTest(
override val serviceShapeId: String,
override val id: String,
override val inAtLeast: Set<String>,
override val trackedIn: Set<String>,
val howToFixItFn: (TestCase.ResponseTest) -> TestCase.ResponseTest,
) : BrokenTest(serviceShapeId, id, inAtLeast, trackedIn)

data class MalformedRequestTest(
override val serviceShapeId: String,
override val id: String,
override val inAtLeast: Set<String>,
override val trackedIn: Set<String>,
val howToFixItFn: (TestCase.MalformedRequestTest) -> TestCase.MalformedRequestTest,
) : BrokenTest(serviceShapeId, id, inAtLeast, trackedIn)

fun fixIt(testToFix: TestCase): TestCase {
check(testToFix.id == this.id)
return when (this) {
is MalformedRequestTest -> howToFixItFn(testToFix as TestCase.MalformedRequestTest)
is RequestTest -> howToFixItFn(testToFix as TestCase.RequestTest)
is ResponseTest -> howToFixItFn(testToFix as TestCase.ResponseTest)
}
}
}

/**
* Service shape IDs in common protocol test suites defined upstream.
*/
Expand All @@ -291,7 +415,16 @@ object ServiceShapeId {
const val REST_JSON_VALIDATION = "aws.protocoltests.restjson.validation#RestJsonValidation"
}

data class FailingTest(val service: String, val id: String, val kind: TestCaseKind)
sealed class FailingTest(open val serviceShapeId: String, open val id: String) {
data class RequestTest(override val serviceShapeId: String, override val id: String) :
FailingTest(serviceShapeId, id)

data class ResponseTest(override val serviceShapeId: String, override val id: String) :
FailingTest(serviceShapeId, id)

data class MalformedRequestTest(override val serviceShapeId: String, override val id: String) :
FailingTest(serviceShapeId, id)
}

sealed class TestCaseKind {
data object Request : TestCaseKind()
Expand All @@ -302,11 +435,60 @@ sealed class TestCaseKind {
}

sealed class TestCase {
data class RequestTest(val testCase: HttpRequestTestCase) : TestCase()
/*
* The properties of these data classes don't implement `equals()` usefully in Smithy, so we delegate to `equals()`
* of their `Node` representations.
*/

data class RequestTest(val testCase: HttpRequestTestCase) : TestCase() {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is RequestTest) return false
return testCase.toNode().equals(other.testCase.toNode())
}

override fun hashCode(): Int = testCase.hashCode()
}

data class ResponseTest(val testCase: HttpResponseTestCase, val targetShape: StructureShape) : TestCase() {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is ResponseTest) return false
return testCase.toNode().equals(other.testCase.toNode())
}

override fun hashCode(): Int = testCase.hashCode()
}

data class MalformedRequestTest(val testCase: HttpMalformedRequestTestCase) : TestCase() {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is MalformedRequestTest) return false
return this.protocol == other.protocol && this.id == other.id && this.documentation == other.documentation &&
this.testCase.request.toNode()
.equals(other.testCase.request.toNode()) &&
this.testCase.response.toNode()
.equals(other.testCase.response.toNode())
}

override fun hashCode(): Int = testCase.hashCode()
}

fun suffixIdWith(suffix: String): TestCase =
when (this) {
is RequestTest -> RequestTest(this.testCase.suffixIdWith(suffix))
is MalformedRequestTest -> MalformedRequestTest(this.testCase.suffixIdWith(suffix))
is ResponseTest -> ResponseTest(this.testCase.suffixIdWith(suffix), this.targetShape)
}

private fun HttpRequestTestCase.suffixIdWith(suffix: String): HttpRequestTestCase =
this.toBuilder().id(this.id + suffix).build()

data class ResponseTest(val testCase: HttpResponseTestCase, val targetShape: StructureShape) : TestCase()
private fun HttpResponseTestCase.suffixIdWith(suffix: String): HttpResponseTestCase =
this.toBuilder().id(this.id + suffix).build()

data class MalformedRequestTest(val testCase: HttpMalformedRequestTestCase) : TestCase()
private fun HttpMalformedRequestTestCase.suffixIdWith(suffix: String): HttpMalformedRequestTestCase =
this.toBuilder().id(this.id + suffix).build()

/*
* `HttpRequestTestCase` and `HttpResponseTestCase` both implement `HttpMessageTestCase`, but
Expand Down
Loading

0 comments on commit 7299cdd

Please sign in to comment.