Skip to content

Commit

Permalink
fix: codegen of types that derive PartialEq
Browse files Browse the repository at this point in the history
fix: codegen issues caused by MSRV upgrade
  • Loading branch information
Velfi authored and Nugine committed Jan 20, 2023
1 parent 4db0023 commit ab2c1d4
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 15 deletions.
2 changes: 1 addition & 1 deletion aws/rust-runtime/aws-sigv4/src/http_request/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::time::Duration;
pub type SigningParams<'a> = crate::SigningParams<'a, SigningSettings>;

/// HTTP-specific signing settings
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct SigningSettings {
/// Specifies how to encode the request URL when signing. Some services do not decode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ internal class EndpointParamsGenerator(private val parameters: Parameters) {
private fun generateEndpointsStruct(writer: RustWriter) {
// Ensure that fields can be added in the future
Attribute.NonExhaustive.render(writer)
// Required in case we ever need to add a field that's not Eq
Attribute.AllowClippyDerivePartialEqWithoutEq.render(writer)
// Automatically implement standard Rust functionality
Attribute(derive(RuntimeType.Debug, RuntimeType.PartialEq, RuntimeType.Clone)).render(writer)
// Generate the struct block:
Expand Down Expand Up @@ -235,6 +237,8 @@ internal class EndpointParamsGenerator(private val parameters: Parameters) {

private fun generateEndpointParamsBuilder(rustWriter: RustWriter) {
rustWriter.docs("Builder for [`Params`]")
// Required in case we ever need to add a param that's not Eq
Attribute.AllowClippyDerivePartialEqWithoutEq.render(rustWriter)
Attribute(derive(RuntimeType.Debug, RuntimeType.Default, RuntimeType.PartialEq, RuntimeType.Clone)).render(rustWriter)
rustWriter.rustBlock("pub struct ParamsBuilder") {
parameters.toList().forEach { parameter ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,14 @@ class ExpressionGenerator(
getAttr.path.toList().forEach { part ->
when (part) {
is GetAttr.Part.Key -> rust(".${part.key().rustName()}()")
is GetAttr.Part.Index -> rust(".get(${part.index()}).cloned()") // we end up with Option<&&T>, we need to get to Option<&T>
is GetAttr.Part.Index -> {
if (part.index() == 0) {
// In this case, `.first()` is more idiomatic and `.get(0)` triggers lint warnings
rust(".first().cloned()")
} else {
rust(".get(${part.index()}).cloned()") // we end up with Option<&&T>, we need to get to Option<&T>
}
}
}
}
if (ownership == Ownership.Owned && getAttr.type() != Type.bool()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,16 @@ fun RustType.isCopy(): Boolean = when (this) {
else -> false
}

/** Returns true if the type implements Eq */
fun RustType.isEq(): Boolean = when (this) {
is RustType.Integer -> true
is RustType.Bool -> true
is RustType.String -> true
is RustType.Unit -> true
is RustType.Container -> this.member.isEq()
else -> false
}

enum class Visibility {
PRIVATE, PUBCRATE, PUBLIC;

Expand Down Expand Up @@ -454,6 +464,7 @@ class Attribute(val inner: Writable) {
val AllowClippyUnnecessaryWraps = Attribute(allow("clippy::unnecessary_wraps"))
val AllowClippyUselessConversion = Attribute(allow("clippy::useless_conversion"))
val AllowClippyUnnecessaryLazyEvaluations = Attribute(allow("clippy::unnecessary_lazy_evaluations"))
val AllowClippyDerivePartialEqWithoutEq = Attribute(allow("clippy::derive_partial_eq_without_eq"))
val AllowDeadCode = Attribute(allow("dead_code"))
val AllowDeprecated = Attribute(allow("deprecated"))
val AllowIrrefutableLetPatterns = Attribute(allow("irrefutable_let_patterns"))
Expand Down Expand Up @@ -546,3 +557,10 @@ class Attribute(val inner: Writable) {
}
}
}

/** Render all attributes in this list, one after another */
fun Collection<Attribute>.render(writer: RustWriter) {
for (attr in this) {
attr.render(writer)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,15 @@ fun containerDefaultMetadata(
model: Model,
additionalAttributes: List<Attribute> = emptyList(),
): RustMetadata {
val defaultDerives = setOf(RuntimeType.Debug, RuntimeType.PartialEq, RuntimeType.Clone)
// A list of `allow` attributes to ignore linter warnings. Each entry in the list must be
// accompanied by a reason.
val allowLints = setOf(
// Required because service team could add non-Eq types at a later date. This
// means we can only ever derive PartialEq.
Attribute.AllowClippyDerivePartialEqWithoutEq,
)

val derives = mutableSetOf(RuntimeType.Debug, RuntimeType.PartialEq, RuntimeType.Clone)

val isSensitive = shape.hasTrait<SensitiveTrait>() ||
// Checking the shape's direct members for the sensitive trait should suffice.
Expand All @@ -101,22 +109,21 @@ fun containerDefaultMetadata(
// shape; any sensitive descendant should still be printed as redacted.
shape.members().any { it.getMemberTrait(model, SensitiveTrait::class.java).isPresent }

val setOfDerives = if (isSensitive) {
defaultDerives - RuntimeType.Debug
} else {
defaultDerives
if (isSensitive) {
derives.remove(RuntimeType.Debug)
}

return RustMetadata(
setOfDerives,
additionalAttributes,
derives,
additionalAttributes + allowLints,
Visibility.PUBLIC,
)
}

/**
* The base metadata supports a set of attributes that are used by generators to decorate code.
*
* By default we apply `#[non_exhaustive]` in [additionalAttributes] only to client structures since breaking model
* By default, we apply `#[non_exhaustive]` in [additionalAttributes] only to client structures since breaking model
* changes are fine when generating server code.
*/
class BaseSymbolMetadataProvider(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ class BuilderGenerator(
private val members: List<MemberShape> = shape.allMembers.values.toList()
private val structureSymbol = symbolProvider.toSymbol(shape)
private val builderSymbol = shape.builderSymbol(symbolProvider)
private val baseDerives = structureSymbol.expectRustMetadata().derives
private val metadata = structureSymbol.expectRustMetadata()

// Filter out any derive that isn't Debug, PartialEq, or Clone. Then add a Default derive
private val builderDerives = baseDerives.filter { it == RuntimeType.Debug || it == RuntimeType.PartialEq || it == RuntimeType.Clone } + RuntimeType.Default
// Filter out any derive that isn't Debug, PartialEq, Eq, or Clone. Then add a Default derive
private val builderDerives = metadata.derives.filter { it == RuntimeType.Debug || it == RuntimeType.PartialEq || it == RuntimeType.Eq || it == RuntimeType.Clone } + RuntimeType.Default
private val builderName = "Builder"

fun render(writer: RustWriter) {
Expand Down Expand Up @@ -207,6 +207,7 @@ class BuilderGenerator(

private fun renderBuilder(writer: RustWriter) {
writer.docs("A builder for #D.", structureSymbol)
metadata.additionalAttributes.render(writer)
Attribute(derive(builderDerives)).render(writer)
writer.rustBlock("pub struct $builderName") {
for (member in members) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -650,12 +650,13 @@ class HttpBindingGenerator(
let $safeName = $formatted;
if !$safeName.is_empty() {
let header_value = $safeName;
let header_value = http::header::HeaderValue::try_from(&*header_value).map_err(|err| {
let header_value: #{HeaderValue} = header_value.parse().map_err(|err| {
#{invalid_field_error:W}
})?;
builder = builder.header("$headerName", header_value);
}
""",
"HeaderValue" to RuntimeType.Http.resolve("HeaderValue"),
"invalid_field_error" to renderErrorMessage("header_value"),
)
}
Expand Down Expand Up @@ -698,13 +699,14 @@ class HttpBindingGenerator(
isMultiValuedHeader = false,
)
};
let header_value = http::header::HeaderValue::try_from(&*header_value).map_err(|err| {
let header_value: #{HeaderValue} = header_value.parse().map_err(|err| {
#{invalid_header_value:W}
})?;
builder = builder.header(header_name, header_value);
}
""",
"HeaderValue" to RuntimeType.Http.resolve("HeaderValue"),
"invalid_header_name" to OperationBuildError(runtimeConfig).invalidField(memberName) {
rust("""format!("`{k}` cannot be used as a header name: {err}")""")
},
Expand Down
1 change: 1 addition & 0 deletions rust-runtime/aws-smithy-eventstream/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ mod value {
const TYPE_UUID: u8 = 9;

/// Event Stream frame header value.
#[allow(clippy::derive_partial_eq_without_eq)]
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq)]
pub enum HeaderValue {
Expand Down

0 comments on commit ab2c1d4

Please sign in to comment.