Skip to content

Commit

Permalink
Auto merge of #98376 - nnethercote:improve-derive-PartialEq, r=petroc…
Browse files Browse the repository at this point in the history
…henkov

Improve some deriving code and add a test

The `.stdout` test is particularly useful.

r? `@petrochenkov`
  • Loading branch information
bors committed Jun 29, 2022
2 parents 2953edc + 02d2cdf commit 126e3df
Show file tree
Hide file tree
Showing 7 changed files with 1,251 additions and 131 deletions.
48 changes: 21 additions & 27 deletions compiler/rustc_builtin_macros/src/deriving/clone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use crate::deriving::generic::*;
use crate::deriving::path_std;

use rustc_ast::ptr::P;
use rustc_ast::{self as ast, Expr, GenericArg, Generics, ItemKind, MetaItem, VariantData};
use rustc_ast::{self as ast, Expr, Generics, ItemKind, MetaItem, VariantData};
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_span::symbol::{kw, sym, Ident, Symbol};
use rustc_span::symbol::{kw, sym, Ident};
use rustc_span::Span;

pub fn expand_deriving_clone(
Expand Down Expand Up @@ -107,44 +107,38 @@ fn cs_clone_shallow(
substr: &Substructure<'_>,
is_union: bool,
) -> P<Expr> {
fn assert_ty_bounds(
cx: &mut ExtCtxt<'_>,
stmts: &mut Vec<ast::Stmt>,
ty: P<ast::Ty>,
span: Span,
helper_name: &str,
) {
// Generate statement `let _: helper_name<ty>;`,
// set the expn ID so we can use the unstable struct.
let span = cx.with_def_site_ctxt(span);
let assert_path = cx.path_all(
span,
true,
cx.std_path(&[sym::clone, Symbol::intern(helper_name)]),
vec![GenericArg::Type(ty)],
);
stmts.push(cx.stmt_let_type_only(span, cx.ty_path(assert_path)));
}
fn process_variant(cx: &mut ExtCtxt<'_>, stmts: &mut Vec<ast::Stmt>, variant: &VariantData) {
let mut stmts = Vec::new();
let mut process_variant = |variant: &VariantData| {
for field in variant.fields() {
// let _: AssertParamIsClone<FieldTy>;
assert_ty_bounds(cx, stmts, field.ty.clone(), field.span, "AssertParamIsClone");
super::assert_ty_bounds(
cx,
&mut stmts,
field.ty.clone(),
field.span,
&[sym::clone, sym::AssertParamIsClone],
);
}
}
};

let mut stmts = Vec::new();
if is_union {
// let _: AssertParamIsCopy<Self>;
let self_ty = cx.ty_path(cx.path_ident(trait_span, Ident::with_dummy_span(kw::SelfUpper)));
assert_ty_bounds(cx, &mut stmts, self_ty, trait_span, "AssertParamIsCopy");
super::assert_ty_bounds(
cx,
&mut stmts,
self_ty,
trait_span,
&[sym::clone, sym::AssertParamIsCopy],
);
} else {
match *substr.fields {
StaticStruct(vdata, ..) => {
process_variant(cx, &mut stmts, vdata);
process_variant(vdata);
}
StaticEnum(enum_def, ..) => {
for variant in &enum_def.variants {
process_variant(cx, &mut stmts, &variant.data);
process_variant(&variant.data);
}
}
_ => cx.span_bug(
Expand Down
44 changes: 14 additions & 30 deletions compiler/rustc_builtin_macros/src/deriving/cmp/eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use crate::deriving::generic::*;
use crate::deriving::path_std;

use rustc_ast::ptr::P;
use rustc_ast::{self as ast, Expr, GenericArg, MetaItem};
use rustc_ast::{self as ast, Expr, MetaItem};
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_span::symbol::{sym, Ident, Symbol};
use rustc_span::symbol::{sym, Ident};
use rustc_span::Span;

pub fn expand_deriving_eq(
Expand Down Expand Up @@ -55,43 +55,27 @@ fn cs_total_eq_assert(
trait_span: Span,
substr: &Substructure<'_>,
) -> P<Expr> {
fn assert_ty_bounds(
cx: &mut ExtCtxt<'_>,
stmts: &mut Vec<ast::Stmt>,
ty: P<ast::Ty>,
span: Span,
helper_name: &str,
) {
// Generate statement `let _: helper_name<ty>;`,
// set the expn ID so we can use the unstable struct.
let span = cx.with_def_site_ctxt(span);
let assert_path = cx.path_all(
span,
true,
cx.std_path(&[sym::cmp, Symbol::intern(helper_name)]),
vec![GenericArg::Type(ty)],
);
stmts.push(cx.stmt_let_type_only(span, cx.ty_path(assert_path)));
}
fn process_variant(
cx: &mut ExtCtxt<'_>,
stmts: &mut Vec<ast::Stmt>,
variant: &ast::VariantData,
) {
let mut stmts = Vec::new();
let mut process_variant = |variant: &ast::VariantData| {
for field in variant.fields() {
// let _: AssertParamIsEq<FieldTy>;
assert_ty_bounds(cx, stmts, field.ty.clone(), field.span, "AssertParamIsEq");
super::assert_ty_bounds(
cx,
&mut stmts,
field.ty.clone(),
field.span,
&[sym::cmp, sym::AssertParamIsEq],
);
}
}
};

let mut stmts = Vec::new();
match *substr.fields {
StaticStruct(vdata, ..) => {
process_variant(cx, &mut stmts, vdata);
process_variant(vdata);
}
StaticEnum(enum_def, ..) => {
for variant in &enum_def.variants {
process_variant(cx, &mut stmts, &variant.data);
process_variant(&variant.data);
}
}
_ => cx.span_bug(trait_span, "unexpected substructure in `derive(Eq)`"),
Expand Down
109 changes: 36 additions & 73 deletions compiler/rustc_builtin_macros/src/deriving/generic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1126,75 +1126,43 @@ impl<'a> MethodDef<'a> {
/// A1,
/// A2(i32)
/// }
///
/// // is equivalent to
///
/// impl PartialEq for A {
/// ```
/// is equivalent to:
/// ```
/// impl ::core::cmp::PartialEq for A {
/// #[inline]
/// fn eq(&self, other: &A) -> bool {
/// use A::*;
/// match (&*self, &*other) {
/// (&A1, &A1) => true,
/// (&A2(ref self_0),
/// &A2(ref __arg_1_0)) => (*self_0).eq(&(*__arg_1_0)),
/// _ => {
/// let __self_vi = match *self { A1 => 0, A2(..) => 1 };
/// let __arg_1_vi = match *other { A1 => 0, A2(..) => 1 };
/// false
/// {
/// let __self_vi = ::core::intrinsics::discriminant_value(&*self);
/// let __arg_1_vi = ::core::intrinsics::discriminant_value(&*other);
/// if true && __self_vi == __arg_1_vi {
/// match (&*self, &*other) {
/// (&A::A2(ref __self_0), &A::A2(ref __arg_1_0)) =>
/// (*__self_0) == (*__arg_1_0),
/// _ => true,
/// }
/// } else {
/// false // catch-all handler
/// }
/// }
/// }
/// }
/// ```
///
/// (Of course `__self_vi` and `__arg_1_vi` are unused for
/// `PartialEq`, and those subcomputations will hopefully be removed
/// as their results are unused. The point of `__self_vi` and
/// `__arg_1_vi` is for `PartialOrd`; see #15503.)
fn expand_enum_method_body<'b>(
&self,
cx: &mut ExtCtxt<'_>,
trait_: &TraitDef<'b>,
enum_def: &'b EnumDef,
type_ident: Ident,
self_args: Vec<P<Expr>>,
nonself_args: &[P<Expr>],
) -> P<Expr> {
self.build_enum_match_tuple(cx, trait_, enum_def, type_ident, self_args, nonself_args)
}

/// Creates a match for a tuple of all `self_args`, where either all
/// variants match, or it falls into a catch-all for when one variant
/// does not match.

///
/// There are N + 1 cases because is a case for each of the N
/// variants where all of the variants match, and one catch-all for
/// when one does not match.

///
/// As an optimization we generate code which checks whether all variants
/// match first which makes llvm see that C-like enums can be compiled into
/// a simple equality check (for PartialEq).

///
/// The catch-all handler is provided access the variant index values
/// for each of the self-args, carried in precomputed variables.

/// ```{.text}
/// let __self0_vi = std::intrinsics::discriminant_value(&self);
/// let __self1_vi = std::intrinsics::discriminant_value(&arg1);
/// let __self2_vi = std::intrinsics::discriminant_value(&arg2);
///
/// if __self0_vi == __self1_vi && __self0_vi == __self2_vi && ... {
/// match (...) {
/// (Variant1, Variant1, ...) => Body1
/// (Variant2, Variant2, ...) => Body2,
/// ...
/// _ => ::core::intrinsics::unreachable()
/// }
/// }
/// else {
/// ... // catch-all remainder can inspect above variant index values.
/// }
/// ```
fn build_enum_match_tuple<'b>(
fn expand_enum_method_body<'b>(
&self,
cx: &mut ExtCtxt<'_>,
trait_: &TraitDef<'b>,
Expand Down Expand Up @@ -1392,37 +1360,32 @@ impl<'a> MethodDef<'a> {
//
// i.e., for `enum E<T> { A, B(1), C(T, T) }`, and a deriving
// with three Self args, builds three statements:
//
// ```
// let __self0_vi = std::intrinsics::discriminant_value(&self);
// let __self1_vi = std::intrinsics::discriminant_value(&arg1);
// let __self2_vi = std::intrinsics::discriminant_value(&arg2);
// let __self_vi = std::intrinsics::discriminant_value(&self);
// let __arg_1_vi = std::intrinsics::discriminant_value(&arg1);
// let __arg_2_vi = std::intrinsics::discriminant_value(&arg2);
// ```
let mut index_let_stmts: Vec<ast::Stmt> = Vec::with_capacity(vi_idents.len() + 1);

// We also build an expression which checks whether all discriminants are equal
// discriminant_test = __self0_vi == __self1_vi && __self0_vi == __self2_vi && ...
// We also build an expression which checks whether all discriminants are equal:
// `__self_vi == __arg_1_vi && __self_vi == __arg_2_vi && ...`
let mut discriminant_test = cx.expr_bool(span, true);

let mut first_ident = None;
for (&ident, self_arg) in iter::zip(&vi_idents, &self_args) {
for (i, (&ident, self_arg)) in iter::zip(&vi_idents, &self_args).enumerate() {
let self_addr = cx.expr_addr_of(span, self_arg.clone());
let variant_value =
deriving::call_intrinsic(cx, span, sym::discriminant_value, vec![self_addr]);
let let_stmt = cx.stmt_let(span, false, ident, variant_value);
index_let_stmts.push(let_stmt);

match first_ident {
Some(first) => {
let first_expr = cx.expr_ident(span, first);
let id = cx.expr_ident(span, ident);
let test = cx.expr_binary(span, BinOpKind::Eq, first_expr, id);
discriminant_test =
cx.expr_binary(span, BinOpKind::And, discriminant_test, test)
}
None => {
first_ident = Some(ident);
}
if i > 0 {
let id0 = cx.expr_ident(span, vi_idents[0]);
let id = cx.expr_ident(span, ident);
let test = cx.expr_binary(span, BinOpKind::Eq, id0, id);
discriminant_test = if i == 1 {
test
} else {
cx.expr_binary(span, BinOpKind::And, discriminant_test, test)
};
}
}

Expand Down Expand Up @@ -1453,7 +1416,7 @@ impl<'a> MethodDef<'a> {
// }
// }
// else {
// <delegated expression referring to __self0_vi, et al.>
// <delegated expression referring to __self_vi, et al.>
// }
let all_match = cx.expr_match(span, match_arg, match_arms);
let arm_expr = cx.expr_if(span, discriminant_test, all_match, Some(arm_expr));
Expand Down
15 changes: 14 additions & 1 deletion compiler/rustc_builtin_macros/src/deriving/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use rustc_ast as ast;
use rustc_ast::ptr::P;
use rustc_ast::{Impl, ItemKind, MetaItem};
use rustc_ast::{GenericArg, Impl, ItemKind, MetaItem};
use rustc_expand::base::{Annotatable, ExpandResult, ExtCtxt, MultiItemModifier};
use rustc_span::symbol::{sym, Ident, Symbol};
use rustc_span::Span;
Expand Down Expand Up @@ -193,3 +193,16 @@ fn inject_impl_of_structural_trait(

push(Annotatable::Item(newitem));
}

fn assert_ty_bounds(
cx: &mut ExtCtxt<'_>,
stmts: &mut Vec<ast::Stmt>,
ty: P<ast::Ty>,
span: Span,
assert_path: &[Symbol],
) {
// Generate statement `let _: assert_path<ty>;`.
let span = cx.with_def_site_ctxt(span);
let assert_path = cx.path_all(span, true, cx.std_path(assert_path), vec![GenericArg::Type(ty)]);
stmts.push(cx.stmt_let_type_only(span, cx.ty_path(assert_path)));
}
3 changes: 3 additions & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ symbols! {
Arguments,
AsMut,
AsRef,
AssertParamIsClone,
AssertParamIsCopy,
AssertParamIsEq,
AtomicBool,
AtomicI128,
AtomicI16,
Expand Down
Loading

0 comments on commit 126e3df

Please sign in to comment.