Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compare trait references in trait_duplication_in_bounds correctly #13493

Merged
merged 1 commit into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 49 additions & 58 deletions clippy_lints/src/trait_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@ use clippy_utils::source::{SpanRangeExt, snippet, snippet_with_applicability};
use clippy_utils::{SpanlessEq, SpanlessHash, is_from_proc_macro};
use core::hash::{Hash, Hasher};
use itertools::Itertools;
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap, IndexEntry};
use rustc_data_structures::unhash::UnhashMap;
use rustc_errors::Applicability;
use rustc_hir::def::Res;
use rustc_hir::{
GenericArg, GenericBound, Generics, Item, ItemKind, LangItem, Node, Path, PathSegment, PredicateOrigin, QPath,
GenericBound, Generics, Item, ItemKind, LangItem, Node, Path, PathSegment, PredicateOrigin, QPath,
TraitBoundModifier, TraitItem, TraitRef, Ty, TyKind, WherePredicate,
};
use rustc_lint::{LateContext, LateLintPass};
use rustc_session::impl_lint_pass;
use rustc_span::{BytePos, Span};
use std::collections::hash_map::Entry;

declare_clippy_lint! {
/// ### What it does
Expand Down Expand Up @@ -153,7 +152,10 @@ impl<'tcx> LateLintPass<'tcx> for TraitBounds {
.filter_map(get_trait_info_from_bound)
.for_each(|(trait_item_res, trait_item_segments, span)| {
if let Some(self_segments) = self_bounds_map.get(&trait_item_res) {
if SpanlessEq::new(cx).eq_path_segments(self_segments, trait_item_segments) {
if SpanlessEq::new(cx)
.paths_by_resolution()
.eq_path_segments(self_segments, trait_item_segments)
{
span_lint_and_help(
cx,
TRAIT_DUPLICATION_IN_BOUNDS,
Expand Down Expand Up @@ -302,7 +304,7 @@ impl TraitBounds {
}
}

fn check_trait_bound_duplication(cx: &LateContext<'_>, generics: &'_ Generics<'_>) {
fn check_trait_bound_duplication<'tcx>(cx: &LateContext<'tcx>, generics: &'_ Generics<'tcx>) {
if generics.span.from_expansion() {
return;
}
Expand All @@ -314,6 +316,7 @@ fn check_trait_bound_duplication(cx: &LateContext<'_>, generics: &'_ Generics<'_
// |
// collects each of these where clauses into a set keyed by generic name and comparable trait
// eg. (T, Clone)
#[expect(clippy::mutable_key_type)]
let where_predicates = generics
.predicates
.iter()
Expand Down Expand Up @@ -367,11 +370,27 @@ fn check_trait_bound_duplication(cx: &LateContext<'_>, generics: &'_ Generics<'_
}
}

#[derive(Clone, PartialEq, Eq, Hash, Debug)]
struct ComparableTraitRef(Res, Vec<Res>);
impl Default for ComparableTraitRef {
fn default() -> Self {
Self(Res::Err, Vec::new())
struct ComparableTraitRef<'a, 'tcx> {
cx: &'a LateContext<'tcx>,
trait_ref: &'tcx TraitRef<'tcx>,
modifier: TraitBoundModifier,
}

impl PartialEq for ComparableTraitRef<'_, '_> {
fn eq(&self, other: &Self) -> bool {
self.modifier == other.modifier
&& SpanlessEq::new(self.cx)
.paths_by_resolution()
.eq_path(self.trait_ref.path, other.trait_ref.path)
}
}
impl Eq for ComparableTraitRef<'_, '_> {}
impl Hash for ComparableTraitRef<'_, '_> {
fn hash<H: Hasher>(&self, state: &mut H) {
let mut s = SpanlessHash::new(self.cx).paths_by_resolution();
s.hash_path(self.trait_ref.path);
state.write_u64(s.finish());
self.modifier.hash(state);
}
}

Expand All @@ -392,69 +411,41 @@ fn get_trait_info_from_bound<'a>(bound: &'a GenericBound<'_>) -> Option<(Res, &'
}
}

fn get_ty_res(ty: Ty<'_>) -> Option<Res> {
match ty.kind {
TyKind::Path(QPath::Resolved(_, path)) => Some(path.res),
TyKind::Path(QPath::TypeRelative(ty, _)) => get_ty_res(*ty),
_ => None,
}
}

// FIXME: ComparableTraitRef does not support nested bounds needed for associated_type_bounds
fn into_comparable_trait_ref(trait_ref: &TraitRef<'_>) -> ComparableTraitRef {
ComparableTraitRef(
trait_ref.path.res,
trait_ref
.path
.segments
.iter()
.filter_map(|segment| {
// get trait bound type arguments
Some(segment.args?.args.iter().filter_map(|arg| {
if let GenericArg::Type(ty) = arg {
return get_ty_res(**ty);
}
None
}))
})
.flatten()
.collect(),
)
}

fn rollup_traits(
cx: &LateContext<'_>,
bounds: &[GenericBound<'_>],
fn rollup_traits<'cx, 'tcx>(
cx: &'cx LateContext<'tcx>,
bounds: &'tcx [GenericBound<'tcx>],
msg: &'static str,
) -> Vec<(ComparableTraitRef, Span)> {
let mut map = FxHashMap::default();
) -> Vec<(ComparableTraitRef<'cx, 'tcx>, Span)> {
// Source order is needed for joining spans
let mut map = FxIndexMap::default();
let mut repeated_res = false;

let only_comparable_trait_refs = |bound: &GenericBound<'_>| {
if let GenericBound::Trait(t, _) = bound {
Some((into_comparable_trait_ref(&t.trait_ref), t.span))
let only_comparable_trait_refs = |bound: &'tcx GenericBound<'tcx>| {
if let GenericBound::Trait(t, modifier) = bound {
Some((
ComparableTraitRef {
cx,
trait_ref: &t.trait_ref,
modifier: *modifier,
},
t.span,
))
} else {
None
}
};

let mut i = 0usize;
for bound in bounds.iter().filter_map(only_comparable_trait_refs) {
let (comparable_bound, span_direct) = bound;
match map.entry(comparable_bound) {
Entry::Occupied(_) => repeated_res = true,
Entry::Vacant(e) => {
e.insert((span_direct, i));
i += 1;
IndexEntry::Occupied(_) => repeated_res = true,
IndexEntry::Vacant(e) => {
e.insert(span_direct);
},
}
}

// Put bounds in source order
let mut comparable_bounds = vec![Default::default(); map.len()];
for (k, (v, i)) in map {
comparable_bounds[i] = (k, v);
}
let comparable_bounds: Vec<_> = map.into_iter().collect();

if repeated_res && let [first_trait, .., last_trait] = bounds {
let all_trait_span = first_trait.span().to(last_trait.span());
Expand Down
119 changes: 106 additions & 13 deletions clippy_utils/src/hir_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::tokenize_with_text;
use rustc_ast::ast::InlineAsmTemplatePiece;
use rustc_data_structures::fx::FxHasher;
use rustc_hir::MatchSource::TryDesugar;
use rustc_hir::def::Res;
use rustc_hir::def::{DefKind, Res};
use rustc_hir::{
ArrayLen, AssocItemConstraint, BinOpKind, BindingMode, Block, BodyId, Closure, ConstArg, ConstArgKind, Expr,
ExprField, ExprKind, FnRetTy, GenericArg, GenericArgs, HirId, HirIdMap, InlineAsmOperand, LetExpr, Lifetime,
Expand All @@ -17,11 +17,33 @@ use rustc_middle::ty::TypeckResults;
use rustc_span::{BytePos, ExpnKind, MacroKind, Symbol, SyntaxContext, sym};
use std::hash::{Hash, Hasher};
use std::ops::Range;
use std::slice;

/// Callback that is called when two expressions are not equal in the sense of `SpanlessEq`, but
/// other conditions would make them equal.
type SpanlessEqCallback<'a> = dyn FnMut(&Expr<'_>, &Expr<'_>) -> bool + 'a;

/// Determines how paths are hashed and compared for equality.
#[derive(Copy, Clone, Debug, Default)]
pub enum PathCheck {
/// Paths must match exactly and are hashed by their exact HIR tree.
///
/// Thus, `std::iter::Iterator` and `Iterator` are not considered equal even though they refer
/// to the same item.
#[default]
Exact,
/// Paths are compared and hashed based on their resolution.
///
/// They can appear different in the HIR tree but are still considered equal
/// and have equal hashes as long as they refer to the same item.
///
/// Note that this is currently only partially implemented specifically for paths that are
/// resolved before type-checking, i.e. the final segment must have a non-error resolution.
/// If a path with an error resolution is encountered, it falls back to the default exact
/// matching behavior.
Resolution,
}

/// Type used to check whether two ast are the same. This is different from the
/// operator `==` on ast types as this operator would compare true equality with
/// ID and span.
Expand All @@ -33,6 +55,7 @@ pub struct SpanlessEq<'a, 'tcx> {
maybe_typeck_results: Option<(&'tcx TypeckResults<'tcx>, &'tcx TypeckResults<'tcx>)>,
allow_side_effects: bool,
expr_fallback: Option<Box<SpanlessEqCallback<'a>>>,
path_check: PathCheck,
}

impl<'a, 'tcx> SpanlessEq<'a, 'tcx> {
Expand All @@ -42,6 +65,7 @@ impl<'a, 'tcx> SpanlessEq<'a, 'tcx> {
maybe_typeck_results: cx.maybe_typeck_results().map(|x| (x, x)),
allow_side_effects: true,
expr_fallback: None,
path_check: PathCheck::default(),
}
}

Expand All @@ -54,6 +78,16 @@ impl<'a, 'tcx> SpanlessEq<'a, 'tcx> {
}
}

/// Check paths by their resolution instead of exact equality. See [`PathCheck`] for more
/// details.
#[must_use]
pub fn paths_by_resolution(self) -> Self {
Self {
path_check: PathCheck::Resolution,
..self
}
}

#[must_use]
pub fn expr_fallback(self, expr_fallback: impl FnMut(&Expr<'_>, &Expr<'_>) -> bool + 'a) -> Self {
Self {
Expand Down Expand Up @@ -498,7 +532,7 @@ impl HirEqInterExpr<'_, '_, '_> {
match (left.res, right.res) {
(Res::Local(l), Res::Local(r)) => l == r || self.locals.get(&l) == Some(&r),
(Res::Local(_), _) | (_, Res::Local(_)) => false,
_ => over(left.segments, right.segments, |l, r| self.eq_path_segment(l, r)),
_ => self.eq_path_segments(left.segments, right.segments),
}
}

Expand All @@ -511,17 +545,39 @@ impl HirEqInterExpr<'_, '_, '_> {
}
}

pub fn eq_path_segments(&mut self, left: &[PathSegment<'_>], right: &[PathSegment<'_>]) -> bool {
left.len() == right.len() && left.iter().zip(right).all(|(l, r)| self.eq_path_segment(l, r))
pub fn eq_path_segments<'tcx>(
&mut self,
mut left: &'tcx [PathSegment<'tcx>],
mut right: &'tcx [PathSegment<'tcx>],
) -> bool {
if let PathCheck::Resolution = self.inner.path_check
&& let Some(left_seg) = generic_path_segments(left)
&& let Some(right_seg) = generic_path_segments(right)
{
// If we compare by resolution, then only check the last segments that could possibly have generic
// arguments
left = left_seg;
right = right_seg;
}

over(left, right, |l, r| self.eq_path_segment(l, r))
}

pub fn eq_path_segment(&mut self, left: &PathSegment<'_>, right: &PathSegment<'_>) -> bool {
// The == of idents doesn't work with different contexts,
// we have to be explicit about hygiene
left.ident.name == right.ident.name
&& both(left.args.as_ref(), right.args.as_ref(), |l, r| {
self.eq_path_parameters(l, r)
})
if !self.eq_path_parameters(left.args(), right.args()) {
return false;
}

if let PathCheck::Resolution = self.inner.path_check
&& left.res != Res::Err
&& right.res != Res::Err
{
left.res == right.res
} else {
// The == of idents doesn't work with different contexts,
// we have to be explicit about hygiene
left.ident.name == right.ident.name
}
}

pub fn eq_ty(&mut self, left: &Ty<'_>, right: &Ty<'_>) -> bool {
Expand Down Expand Up @@ -684,6 +740,21 @@ pub fn eq_expr_value(cx: &LateContext<'_>, left: &Expr<'_>, right: &Expr<'_>) ->
SpanlessEq::new(cx).deny_side_effects().eq_expr(left, right)
}

/// Returns the segments of a path that might have generic parameters.
/// Usually just the last segment for free items, except for when the path resolves to an associated
/// item, in which case it is the last two
fn generic_path_segments<'tcx>(segments: &'tcx [PathSegment<'tcx>]) -> Option<&'tcx [PathSegment<'tcx>]> {
match segments.last()?.res {
Res::Def(DefKind::AssocConst | DefKind::AssocFn | DefKind::AssocTy, _) => {
// <Ty as module::Trait<T>>::assoc::<U>
// ^^^^^^^^^^^^^^^^ ^^^^^^^^^^ segments: [module, Trait<T>, assoc<U>]
Some(&segments[segments.len().checked_sub(2)?..])
},
Res::Err => None,
_ => Some(slice::from_ref(segments.last()?)),
}
}

/// Type used to hash an ast element. This is different from the `Hash` trait
/// on ast types as this
/// trait would consider IDs and spans.
Expand All @@ -694,17 +765,29 @@ pub struct SpanlessHash<'a, 'tcx> {
cx: &'a LateContext<'tcx>,
maybe_typeck_results: Option<&'tcx TypeckResults<'tcx>>,
s: FxHasher,
path_check: PathCheck,
}

impl<'a, 'tcx> SpanlessHash<'a, 'tcx> {
pub fn new(cx: &'a LateContext<'tcx>) -> Self {
Self {
cx,
maybe_typeck_results: cx.maybe_typeck_results(),
path_check: PathCheck::default(),
s: FxHasher::default(),
}
}

/// Check paths by their resolution instead of exact equality. See [`PathCheck`] for more
/// details.
#[must_use]
pub fn paths_by_resolution(self) -> Self {
Self {
path_check: PathCheck::Resolution,
..self
}
}

pub fn finish(self) -> u64 {
self.s.finish()
}
Expand Down Expand Up @@ -1042,9 +1125,19 @@ impl<'a, 'tcx> SpanlessHash<'a, 'tcx> {
// even though the binding names are different and they have different `HirId`s.
Res::Local(_) => 1_usize.hash(&mut self.s),
_ => {
for seg in path.segments {
self.hash_name(seg.ident.name);
self.hash_generic_args(seg.args().args);
if let PathCheck::Resolution = self.path_check
&& let [.., last] = path.segments
&& let Some(segments) = generic_path_segments(path.segments)
{
for seg in segments {
self.hash_generic_args(seg.args().args);
}
last.res.hash(&mut self.s);
} else {
for seg in path.segments {
self.hash_name(seg.ident.name);
self.hash_generic_args(seg.args().args);
}
}
},
}
Expand Down
Loading