Skip to content

Commit

Permalink
Allow repeated-equality-comparison for mixed operations (#12369)
Browse files Browse the repository at this point in the history
## Summary

This PR allows us to fix both expressions in `foo == "a" or foo == "b"
or ("c" != bar and "d" != bar)`, but limits the rule to consecutive
comparisons, following #7797.

I think this logic was _probably_ added because of
#12368 -- the intent being that
we'd replace the _entire_ expression.
  • Loading branch information
charliermarsh authored Jul 18, 2024
1 parent 9b9d701 commit 764d9ab
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,7 @@
foo == "a" or ("c" == bar or "d" == bar) or foo == "b" # Multiple targets

foo == "a" or foo == "b" or "c" != bar and "d" != bar # Multiple targets

foo == "a" or ("c" != bar and "d" != bar) or foo == "b" # Multiple targets

foo == "a" and "c" != bar or foo == "b" and "d" != bar # Multiple targets
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::ops::Deref;

use itertools::Itertools;
use rustc_hash::{FxBuildHasher, FxHashMap};

Expand Down Expand Up @@ -72,79 +70,83 @@ impl AlwaysFixableViolation for RepeatedEqualityComparison {

/// PLR1714
pub(crate) fn repeated_equality_comparison(checker: &mut Checker, bool_op: &ast::ExprBoolOp) {
if bool_op
.values
.iter()
.any(|value| !is_allowed_value(bool_op.op, value, checker.semantic()))
{
return;
}

// Map from expression hash to (starting offset, number of comparisons, list
let mut value_to_comparators: FxHashMap<HashableExpr, (TextSize, Vec<&Expr>, Vec<&Expr>)> =
let mut value_to_comparators: FxHashMap<HashableExpr, (TextSize, Vec<&Expr>, Vec<usize>)> =
FxHashMap::with_capacity_and_hasher(bool_op.values.len() * 2, FxBuildHasher);

for value in &bool_op.values {
// Enforced via `is_allowed_value`.
let Expr::Compare(ast::ExprCompare {
left, comparators, ..
}) = value
else {
return;
};

// Enforced via `is_allowed_value`.
let [right] = &**comparators else {
return;
for (i, value) in bool_op.values.iter().enumerate() {
let Some((left, right)) = to_allowed_value(bool_op.op, value, checker.semantic()) else {
continue;
};

if matches!(left.as_ref(), Expr::Name(_) | Expr::Attribute(_)) {
let (_, left_matches, value_matches) = value_to_comparators
.entry(left.deref().into())
if matches!(left, Expr::Name(_) | Expr::Attribute(_)) {
let (_, left_matches, index_matches) = value_to_comparators
.entry(left.into())
.or_insert_with(|| (left.start(), Vec::new(), Vec::new()));
left_matches.push(right);
value_matches.push(value);
index_matches.push(i);
}

if matches!(right, Expr::Name(_) | Expr::Attribute(_)) {
let (_, right_matches, value_matches) = value_to_comparators
let (_, right_matches, index_matches) = value_to_comparators
.entry(right.into())
.or_insert_with(|| (right.start(), Vec::new(), Vec::new()));
right_matches.push(left);
value_matches.push(value);
index_matches.push(i);
}
}

for (value, (start, comparators, values)) in value_to_comparators
for (value, (_, comparators, indices)) in value_to_comparators
.iter()
.sorted_by_key(|(_, (start, _, _))| *start)
{
if comparators.len() > 1 {
// If there's only one comparison, there's nothing to merge.
if comparators.len() == 1 {
continue;
}

// Break into sequences of consecutive comparisons.
let mut sequences: Vec<(Vec<usize>, Vec<&Expr>)> = Vec::new();
let mut last = None;
for (index, comparator) in indices.iter().zip(comparators.iter()) {
if last.is_some_and(|last| last + 1 == *index) {
let (indices, comparators) = sequences.last_mut().unwrap();
indices.push(*index);
comparators.push(*comparator);
} else {
sequences.push((vec![*index], vec![*comparator]));
}
last = Some(*index);
}

for (indices, comparators) in sequences {
if indices.len() == 1 {
continue;
}

let mut diagnostic = Diagnostic::new(
RepeatedEqualityComparison {
expression: SourceCodeSnippet::new(merged_membership_test(
value.as_expr(),
bool_op.op,
comparators,
&comparators,
checker.locator(),
)),
},
bool_op.range(),
);

// Grab the remaining comparisons.
let (before, after) = bool_op
.values
.iter()
.filter(|value| !values.contains(value))
.partition::<Vec<_>, _>(|value| value.start() < *start);
let [first, .., last] = indices.as_slice() else {
unreachable!("Indices should have at least two elements")
};
let before = bool_op.values.iter().take(*first).cloned();
let after = bool_op.values.iter().skip(last + 1).cloned();

diagnostic.set_fix(Fix::unsafe_edit(Edit::range_replacement(
checker.generator().expr(&Expr::BoolOp(ast::ExprBoolOp {
op: bool_op.op,
values: before
.into_iter()
.cloned()
.chain(std::iter::once(Expr::Compare(ast::ExprCompare {
left: Box::new(value.as_expr().clone()),
ops: match bool_op.op {
Expand All @@ -159,7 +161,7 @@ pub(crate) fn repeated_equality_comparison(checker: &mut Checker, bool_op: &ast:
})]),
range: bool_op.range(),
})))
.chain(after.into_iter().cloned())
.chain(after)
.collect(),
range: bool_op.range(),
})),
Expand All @@ -174,39 +176,43 @@ pub(crate) fn repeated_equality_comparison(checker: &mut Checker, bool_op: &ast:
/// Return `true` if the given expression is compatible with a membership test.
/// E.g., `==` operators can be joined with `or` and `!=` operators can be
/// joined with `and`.
fn is_allowed_value(bool_op: BoolOp, value: &Expr, semantic: &SemanticModel) -> bool {
fn to_allowed_value<'a>(
bool_op: BoolOp,
value: &'a Expr,
semantic: &SemanticModel,
) -> Option<(&'a Expr, &'a Expr)> {
let Expr::Compare(ast::ExprCompare {
left,
ops,
comparators,
..
}) = value
else {
return false;
return None;
};

// Ignore, e.g., `foo == bar == baz`.
let [op] = &**ops else {
return false;
return None;
};

if match bool_op {
BoolOp::Or => !matches!(op, CmpOp::Eq),
BoolOp::And => !matches!(op, CmpOp::NotEq),
} {
return false;
return None;
}

// Ignore self-comparisons, e.g., `foo == foo`.
let [right] = &**comparators else {
return false;
return None;
};
if ComparableExpr::from(left) == ComparableExpr::from(right) {
return false;
return None;
}

if contains_effect(value, |id| semantic.has_builtin_binding(id)) {
return false;
return None;
}

// Ignore `sys.version_info` and `sys.platform` comparisons, which are only
Expand All @@ -221,10 +227,10 @@ fn is_allowed_value(bool_op: BoolOp, value: &Expr, semantic: &SemanticModel) ->
)
})
}) {
return false;
return None;
}

true
Some((left, right))
}

/// Generate a string like `obj in (a, b, c)` or `obj not in (a, b, c)`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub struct RepeatedIsinstanceCalls {
expression: SourceCodeSnippet,
}

// PLR1701
/// PLR1701
impl AlwaysFixableViolation for RepeatedIsinstanceCalls {
#[derive_message_formats]
fn message(&self) -> String {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,80 +292,85 @@ repeated_equality_comparison.py:26:1: PLR1714 [*] Consider merging multiple comp
28 28 | # OK
29 29 | foo == "a" and foo == "b" and foo == "c" # `and` mixed with `==`.

repeated_equality_comparison.py:59:1: PLR1714 [*] Consider merging multiple comparisons: `foo in ("a", "b")`. Use a `set` if the elements are hashable.
repeated_equality_comparison.py:61:16: PLR1714 [*] Consider merging multiple comparisons: `bar in ("c", "d")`. Use a `set` if the elements are hashable.
|
57 | sys.platform == "win32" or sys.platform == "emscripten" # sys attributes
58 |
59 | foo == "a" or "c" == bar or foo == "b" or "d" == bar # Multiple targets
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PLR1714
60 |
61 | foo == "a" or ("c" == bar or "d" == bar) or foo == "b" # Multiple targets
| ^^^^^^^^^^^^^^^^^^^^^^^^ PLR1714
62 |
63 | foo == "a" or foo == "b" or "c" != bar and "d" != bar # Multiple targets
|
= help: Merge multiple comparisons

Unsafe fix
56 56 |
57 57 | sys.platform == "win32" or sys.platform == "emscripten" # sys attributes
58 58 |
59 |-foo == "a" or "c" == bar or foo == "b" or "d" == bar # Multiple targets
59 |+foo in ("a", "b") or "c" == bar or "d" == bar # Multiple targets
59 59 | foo == "a" or "c" == bar or foo == "b" or "d" == bar # Multiple targets
60 60 |
61 61 | foo == "a" or ("c" == bar or "d" == bar) or foo == "b" # Multiple targets
61 |-foo == "a" or ("c" == bar or "d" == bar) or foo == "b" # Multiple targets
61 |+foo == "a" or (bar in ("c", "d")) or foo == "b" # Multiple targets
62 62 |
63 63 | foo == "a" or foo == "b" or "c" != bar and "d" != bar # Multiple targets
64 64 |

repeated_equality_comparison.py:59:1: PLR1714 [*] Consider merging multiple comparisons: `bar in ("c", "d")`. Use a `set` if the elements are hashable.
repeated_equality_comparison.py:63:1: PLR1714 [*] Consider merging multiple comparisons: `foo in ("a", "b")`. Use a `set` if the elements are hashable.
|
57 | sys.platform == "win32" or sys.platform == "emscripten" # sys attributes
58 |
59 | foo == "a" or "c" == bar or foo == "b" or "d" == bar # Multiple targets
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PLR1714
60 |
61 | foo == "a" or ("c" == bar or "d" == bar) or foo == "b" # Multiple targets
62 |
63 | foo == "a" or foo == "b" or "c" != bar and "d" != bar # Multiple targets
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PLR1714
64 |
65 | foo == "a" or ("c" != bar and "d" != bar) or foo == "b" # Multiple targets
|
= help: Merge multiple comparisons

Unsafe fix
56 56 |
57 57 | sys.platform == "win32" or sys.platform == "emscripten" # sys attributes
58 58 |
59 |-foo == "a" or "c" == bar or foo == "b" or "d" == bar # Multiple targets
59 |+foo == "a" or bar in ("c", "d") or foo == "b" # Multiple targets
60 60 |
61 61 | foo == "a" or ("c" == bar or "d" == bar) or foo == "b" # Multiple targets
62 62 |
63 |-foo == "a" or foo == "b" or "c" != bar and "d" != bar # Multiple targets
63 |+foo in ("a", "b") or "c" != bar and "d" != bar # Multiple targets
64 64 |
65 65 | foo == "a" or ("c" != bar and "d" != bar) or foo == "b" # Multiple targets
66 66 |

repeated_equality_comparison.py:61:16: PLR1714 [*] Consider merging multiple comparisons: `bar in ("c", "d")`. Use a `set` if the elements are hashable.
repeated_equality_comparison.py:63:29: PLR1714 [*] Consider merging multiple comparisons: `bar not in ("c", "d")`. Use a `set` if the elements are hashable.
|
59 | foo == "a" or "c" == bar or foo == "b" or "d" == bar # Multiple targets
60 |
61 | foo == "a" or ("c" == bar or "d" == bar) or foo == "b" # Multiple targets
| ^^^^^^^^^^^^^^^^^^^^^^^^ PLR1714
62 |
63 | foo == "a" or foo == "b" or "c" != bar and "d" != bar # Multiple targets
| ^^^^^^^^^^^^^^^^^^^^^^^^^ PLR1714
64 |
65 | foo == "a" or ("c" != bar and "d" != bar) or foo == "b" # Multiple targets
|
= help: Merge multiple comparisons

Unsafe fix
58 58 |
59 59 | foo == "a" or "c" == bar or foo == "b" or "d" == bar # Multiple targets
60 60 |
61 |-foo == "a" or ("c" == bar or "d" == bar) or foo == "b" # Multiple targets
61 |+foo == "a" or (bar in ("c", "d")) or foo == "b" # Multiple targets
61 61 | foo == "a" or ("c" == bar or "d" == bar) or foo == "b" # Multiple targets
62 62 |
63 63 | foo == "a" or foo == "b" or "c" != bar and "d" != bar # Multiple targets
63 |-foo == "a" or foo == "b" or "c" != bar and "d" != bar # Multiple targets
63 |+foo == "a" or foo == "b" or bar not in ("c", "d") # Multiple targets
64 64 |
65 65 | foo == "a" or ("c" != bar and "d" != bar) or foo == "b" # Multiple targets
66 66 |

repeated_equality_comparison.py:63:29: PLR1714 [*] Consider merging multiple comparisons: `bar not in ("c", "d")`. Use a `set` if the elements are hashable.
repeated_equality_comparison.py:65:16: PLR1714 [*] Consider merging multiple comparisons: `bar not in ("c", "d")`. Use a `set` if the elements are hashable.
|
61 | foo == "a" or ("c" == bar or "d" == bar) or foo == "b" # Multiple targets
62 |
63 | foo == "a" or foo == "b" or "c" != bar and "d" != bar # Multiple targets
| ^^^^^^^^^^^^^^^^^^^^^^^^^ PLR1714
64 |
65 | foo == "a" or ("c" != bar and "d" != bar) or foo == "b" # Multiple targets
| ^^^^^^^^^^^^^^^^^^^^^^^^^ PLR1714
66 |
67 | foo == "a" and "c" != bar or foo == "b" and "d" != bar # Multiple targets
|
= help: Merge multiple comparisons

Unsafe fix
60 60 |
61 61 | foo == "a" or ("c" == bar or "d" == bar) or foo == "b" # Multiple targets
62 62 |
63 |-foo == "a" or foo == "b" or "c" != bar and "d" != bar # Multiple targets
63 |+foo == "a" or foo == "b" or bar not in ("c", "d") # Multiple targets
63 63 | foo == "a" or foo == "b" or "c" != bar and "d" != bar # Multiple targets
64 64 |
65 |-foo == "a" or ("c" != bar and "d" != bar) or foo == "b" # Multiple targets
65 |+foo == "a" or (bar not in ("c", "d")) or foo == "b" # Multiple targets
66 66 |
67 67 | foo == "a" and "c" != bar or foo == "b" and "d" != bar # Multiple targets

0 comments on commit 764d9ab

Please sign in to comment.