Skip to content

Commit

Permalink
Fix substring marker expression disjointness checks (#4998)
Browse files Browse the repository at this point in the history
## Summary

Noticed a bug here, `'a' in env` and `env not in 'a'` are not disjoint
given `env == 'ab'`.
  • Loading branch information
ibraheemdev authored Jul 12, 2024
1 parent abdb58d commit a1f71a3
Showing 1 changed file with 62 additions and 12 deletions.
74 changes: 62 additions & 12 deletions crates/uv-resolver/src/marker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,47 @@ pub(crate) fn is_disjoint(first: &MarkerTree, second: &MarkerTree) -> bool {
fn string_is_disjoint(this: &MarkerExpression, other: &MarkerExpression) -> bool {
use MarkerOperator::*;

let (key, operator, value) = extract_string_expression(this).unwrap();
// The `in` and `not in` operators are not reversible, so we have to ensure the expressions
// match exactly. Notably, `'a' in env` and `env not in 'a'` are not disjoint given `env == 'ab'`.
match (this, other) {
(
MarkerExpression::String {
key,
operator,
value,
},
MarkerExpression::String {
key: key2,
operator: operator2,
value: value2,
},
)
| (
MarkerExpression::StringInverted {
key,
operator,
value,
},
MarkerExpression::StringInverted {
key: key2,
operator: operator2,
value: value2,
},
) if key == key2 => match (operator, operator2) {
// The only disjoint expressions involving these operators are `key in value`
// and `key not in value`, or reversed.
(In, NotIn) | (NotIn, In) => return value == value2,
// Anything else cannot be disjoint.
(In | NotIn, _) | (_, In | NotIn) => return false,
_ => {}
},
_ => {}
}

// Extract the normalized string expressions.
let Some((key, operator, value)) = extract_string_expression(this) else {
return false;
};
let Some((key2, operator2, value2)) = extract_string_expression(other) else {
return false;
};
Expand All @@ -63,9 +103,6 @@ fn string_is_disjoint(this: &MarkerExpression, other: &MarkerExpression) -> bool
// The only disjoint expressions involving strict inequality are `key != value` and `key == value`.
(NotEqual, Equal) | (Equal, NotEqual) => return value == value2,
(NotEqual, _) | (_, NotEqual) => return false,
// Similarly for `in` and `not in`.
(In, NotIn) | (NotIn, In) => return value == value2,
(In | NotIn, _) | (_, In | NotIn) => return false,
_ => {}
}

Expand Down Expand Up @@ -560,8 +597,8 @@ fn extract_string_expression(
operator,
key,
} => {
// if the expression was inverted, we have to reverse the operator
Some((key, reverse_marker_operator(*operator), value))
// If the expression was inverted, we have to reverse the marker operator.
reverse_marker_operator(*operator).map(|operator| (key, operator, value.as_str()))
}
_ => None,
}
Expand Down Expand Up @@ -675,6 +712,7 @@ fn keyed_range(expr: &MarkerExpression) -> Option<(&MarkerValueVersion, PubGrubR
/// Reverses a binary operator.
fn reverse_operator(operator: Operator) -> Operator {
use Operator::*;

match operator {
LessThan => GreaterThan,
LessThanEqual => GreaterThanEqual,
Expand All @@ -684,16 +722,21 @@ fn reverse_operator(operator: Operator) -> Operator {
}
}

/// Reverses a marker operator.
fn reverse_marker_operator(operator: MarkerOperator) -> MarkerOperator {
/// Reverses a marker operator, if possible.
fn reverse_marker_operator(operator: MarkerOperator) -> Option<MarkerOperator> {
use MarkerOperator::*;
match operator {

Some(match operator {
LessThan => GreaterThan,
LessEqual => GreaterEqual,
GreaterThan => LessThan,
GreaterEqual => LessEqual,
_ => operator,
}
Equal => Equal,
NotEqual => NotEqual,
TildeEqual => TildeEqual,
// The `in` and `not in` operators are not reversible.
In | NotIn => return None,
})
}

#[cfg(test)]
Expand Down Expand Up @@ -936,7 +979,14 @@ mod tests {
"os_name in 'Windows'",
"os_name not in 'Windows'"
));
assert!(is_disjoint("'Linux' in os_name", "os_name not in 'Linux'"));
assert!(is_disjoint(
"'Windows' in os_name",
"'Windows' not in os_name"
));

assert!(!is_disjoint("'Windows' in os_name", "'Windows' in os_name"));
assert!(!is_disjoint("'Linux' in os_name", "os_name not in 'Linux'"));
assert!(!is_disjoint("'Linux' not in os_name", "os_name in 'Linux'"));
}

#[test]
Expand Down

0 comments on commit a1f71a3

Please sign in to comment.