diff --git a/crates/ty_python_semantic/resources/mdtest/binary/integers.md b/crates/ty_python_semantic/resources/mdtest/binary/integers.md index 30561981810b3..3834a90e0ba20 100644 --- a/crates/ty_python_semantic/resources/mdtest/binary/integers.md +++ b/crates/ty_python_semantic/resources/mdtest/binary/integers.md @@ -1,5 +1,11 @@ # Binary operations on integers +> Developer's note: This is mainly a test for the behavior of the type inferer. The constant +> evaluator (`resolve_to_literal`) of `SemanticIndexBuilder` is implemented separately from the type +> inferer, so if you modify the contents of this file or the type inferer, please also modify the +> implementation of `resolve_to_literal` and the unit tests (semantic_index/tests/const_eval\_\*) at +> the same time. + ## Basic Arithmetic ```py diff --git a/crates/ty_python_semantic/resources/mdtest/loops/while_loop.md b/crates/ty_python_semantic/resources/mdtest/loops/while_loop.md index e917582c5804a..8e88a68e63f9c 100644 --- a/crates/ty_python_semantic/resources/mdtest/loops/while_loop.md +++ b/crates/ty_python_semantic/resources/mdtest/loops/while_loop.md @@ -247,21 +247,45 @@ Here the loop condition forces `x` to be `False` at loop exit, because there is def random() -> bool: return True -x = random() -reveal_type(x) # revealed: bool -while x: - pass -reveal_type(x) # revealed: Literal[False] +def _(x: bool): + while x: + pass + reveal_type(x) # revealed: Literal[False] ``` However, we can't narrow `x` like this when there's a `break` in the loop: ```py -x = random() -while x: - if random(): +def _(x: bool): + while x: + if random(): + break + reveal_type(x) # revealed: bool + +def _(x: bool): + while x: + pass + reveal_type(x) # revealed: Literal[False] + + x = random() + while x: + if random(): + break + reveal_type(x) # revealed: bool + +def _(y: int | None): + x = 1 + while True: + if x == 0: + break + + if y is None: + y = 0 + continue + break -reveal_type(x) # revealed: bool + + reveal_type(y) # revealed: int ``` ### Non-static loop conditions diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md b/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md index 76d96d746baf1..d897e7836220b 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md @@ -180,14 +180,26 @@ def _(x: int | None): ``` ```py +from typing import Final + def _(x: int | None): if 1 + 1 == 2: if x is None: return reveal_type(x) # revealed: int - # TODO: should be `int` (the else-branch of `1 + 1 == 2` is unreachable) - reveal_type(x) # revealed: int | None + reveal_type(x) # revealed: int + +# non-constant but always-true condition +needs_inference: Final = True + +def _(x: int | None): + if needs_inference: + if x is None: + return + reveal_type(x) # revealed: int + + reveal_type(x) # revealed: int ``` This also works when the always-true condition is nested inside a narrowing branch: @@ -198,9 +210,14 @@ def _(x: int | None): if 1 + 1 == 2: return - # TODO: should be `int` (the inner always-true branch makes the outer - # if-branch terminal) - reveal_type(x) # revealed: int | None + reveal_type(x) # revealed: int + +def _(x: int | None): + if x is None: + if needs_inference: + return + + reveal_type(x) # revealed: int ``` ## Narrowing from `assert` should not affect reassigned variables diff --git a/crates/ty_python_semantic/src/semantic_index.rs b/crates/ty_python_semantic/src/semantic_index.rs index e49d3188104ec..d64bde0cef810 100644 --- a/crates/ty_python_semantic/src/semantic_index.rs +++ b/crates/ty_python_semantic/src/semantic_index.rs @@ -925,6 +925,50 @@ mod tests { .collect() } + /// A function to test how the constant evaluator of `SemanticIndexBuilder` evaluates an expression + /// (the evaluation should match that of `TypeInferenceBuilder`). + /// For example, for the input `x = 1\nif cond: x = 2\nx`, if `cond` evaluates to `AlwaysTrue`, it returns `vec![2]`, + /// if it evaluates to `AlwaysFalse`, it returns `vec![1]`, ​​if it evaluates to `Ambiguous`, it returns `vec![1, 2]`. + fn reachable_bindings_for_terminal_use(content: &str) -> Vec { + let TestCase { db, file } = test_case(content); + let scope = global_scope(&db, file); + let module = parsed_module(&db, file).load(&db); + let ast = module.syntax(); + + let terminal_expr = ast + .body + .last() + .and_then(ast::Stmt::as_expr_stmt) + .map(|stmt| stmt.value.as_ref()) + .expect("expected terminal expression statement"); + let terminal_name = terminal_expr + .as_name_expr() + .expect("terminal expression should be a name"); + + let use_id = terminal_name.scoped_use_id(&db, scope); + let use_def = use_def_map(&db, scope); + + use_def + .bindings_at_use(use_id) + .filter_map(|binding_with_constraints| { + let definition = binding_with_constraints.binding.definition()?; + let DefinitionKind::Assignment(assignment) = definition.kind(&db) else { + return None; + }; + + let ast::Expr::NumberLiteral(ast::ExprNumberLiteral { + value: ast::Number::Int(value), + .. + }) = assignment.value(&module) + else { + return None; + }; + + value.as_i64() + }) + .collect::>() + } + #[test] fn empty() { let TestCase { db, file } = test_case(""); @@ -1590,6 +1634,71 @@ class C[T]: assert_eq!(*num, 1); } + #[test] + fn const_eval_lshift_overflow_is_ambiguous() { + let values = reachable_bindings_for_terminal_use( + " +x = 1 +if 1 << 63: + x = 2 +x +", + ); + assert_eq!(values, vec![1, 2]); + } + + #[test] + fn const_eval_lshift_zero_short_circuit() { + let values = reachable_bindings_for_terminal_use( + " +x = 1 +if 0 << 4000000000000000000: + x = 2 +x +", + ); + assert_eq!(values, vec![1]); + } + + #[test] + fn const_eval_rshift_large_positive() { + let values = reachable_bindings_for_terminal_use( + " +x = 1 +if 1 >> 5000000000: + x = 2 +x +", + ); + assert_eq!(values, vec![1]); + } + + #[test] + fn const_eval_rshift_large_negative_operand() { + let values = reachable_bindings_for_terminal_use( + " +x = 1 +if (-1) >> 5000000000: + x = 2 +x +", + ); + assert_eq!(values, vec![2]); + } + + #[test] + fn const_eval_negative_lshift_is_ambiguous() { + let values = reachable_bindings_for_terminal_use( + " +x = 1 +if 42 << -3: + x = 2 +x +", + ); + assert_eq!(values, vec![1, 2]); + } + #[test] fn expression_scope() { let TestCase { db, file } = test_case("x = 1;\ndef test():\n y = 4"); diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index 76cb3c3df1ef3..2c08129318283 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -908,27 +908,253 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { } fn build_predicate(&mut self, predicate_node: &ast::Expr) -> PredicateOrLiteral<'db> { + /// Returns if the expression is a `TYPE_CHECKING` expression. + fn is_if_type_checking(expr: &ast::Expr) -> bool { + fn is_dotted_name(expr: &ast::Expr) -> bool { + match expr { + ast::Expr::Name(_) => true, + ast::Expr::Attribute(ast::ExprAttribute { value, .. }) => is_dotted_name(value), + _ => false, + } + } + + match expr { + ast::Expr::Name(ast::ExprName { id, .. }) => id == "TYPE_CHECKING", + ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { + attr == "TYPE_CHECKING" && is_dotted_name(value) + } + _ => false, + } + } + // Some commonly used test expressions are eagerly evaluated as `true` // or `false` here for performance reasons. This list does not need to // be exhaustive. More complex expressions will still evaluate to the // correct value during type-checking. fn resolve_to_literal(node: &ast::Expr) -> Option { - match node { - ast::Expr::BooleanLiteral(ast::ExprBooleanLiteral { value, .. }) => Some(*value), - ast::Expr::Name(ast::ExprName { id, .. }) if id == "TYPE_CHECKING" => Some(true), - ast::Expr::NumberLiteral(ast::ExprNumberLiteral { - value: ast::Number::Int(n), - .. - }) => Some(*n != 0), - ast::Expr::EllipsisLiteral(_) => Some(true), - ast::Expr::NoneLiteral(_) => Some(false), - ast::Expr::UnaryOp(ast::ExprUnaryOp { - op: ast::UnaryOp::Not, - operand, - .. - }) => Some(!resolve_to_literal(operand)?), - _ => None, + #[derive(Copy, Clone)] + enum ConstExpr { + Bool(bool), + Int(i64), + None, + Ellipsis, + } + + impl ConstExpr { + fn truthiness(self) -> bool { + match self { + ConstExpr::Bool(value) => value, + ConstExpr::Int(value) => value != 0, + ConstExpr::None => false, + ConstExpr::Ellipsis => true, + } + } + + fn as_int(self) -> Option { + match self { + ConstExpr::Int(value) => Some(value), + ConstExpr::Bool(value) => Some(i64::from(value)), + _ => None, + } + } + } + + fn resolve_const_expr(node: &ast::Expr) -> Option { + match node { + ast::Expr::BooleanLiteral(ast::ExprBooleanLiteral { value, .. }) => { + Some(ConstExpr::Bool(*value)) + } + ast::Expr::NumberLiteral(ast::ExprNumberLiteral { + value: ast::Number::Int(n), + .. + }) => n.as_i64().map(ConstExpr::Int), + ast::Expr::EllipsisLiteral(_) => Some(ConstExpr::Ellipsis), + ast::Expr::NoneLiteral(_) => Some(ConstExpr::None), + // See also: `TypeInferenceBuilder::infer_unary_expression_type` + ast::Expr::UnaryOp(ast::ExprUnaryOp { op, operand, .. }) => { + let operand = resolve_const_expr(operand)?; + match op { + ast::UnaryOp::Not => Some(ConstExpr::Bool(!operand.truthiness())), + ast::UnaryOp::UAdd => Some(ConstExpr::Int(operand.as_int()?)), + ast::UnaryOp::USub => { + Some(ConstExpr::Int(operand.as_int()?.checked_neg()?)) + } + ast::UnaryOp::Invert => Some(ConstExpr::Int(!operand.as_int()?)), + } + } + // See also: `TypeInferenceBuilder::infer_binary_expression_type` + ast::Expr::BinOp(ast::ExprBinOp { + left, op, right, .. + }) => { + let left = resolve_const_expr(left)?.as_int()?; + let right = resolve_const_expr(right)?.as_int()?; + let value = match op { + ast::Operator::Add => left.checked_add(right)?, + ast::Operator::Sub => left.checked_sub(right)?, + ast::Operator::Mult => left.checked_mul(right)?, + ast::Operator::FloorDiv => { + let mut q = left.checked_div(right); + let r = left.checked_rem(right); + // Division works differently in Python than in Rust. If the + // result is negative and there is a remainder, floor division + // rounds down (instead of toward zero). + if left.is_negative() != right.is_negative() && r.unwrap_or(0) != 0 + { + q = q.map(|q| q - 1); + } + q? + } + ast::Operator::Mod => { + let mut r = left.checked_rem(right); + // Python's modulo keeps the sign of the divisor. Adjust the Rust + // remainder accordingly so that `q * right + r == left`. + if left.is_negative() != right.is_negative() && r.unwrap_or(0) != 0 + { + r = r.map(|x| x + right); + } + r? + } + ast::Operator::BitAnd => left & right, + ast::Operator::BitOr => left | right, + ast::Operator::BitXor => left ^ right, + ast::Operator::LShift => { + if left == 0 && right >= 0 { + 0 + } else { + // An additional overflow check beyond `checked_shl` is + // necessary here, because `checked_shl` only rejects shift + // amounts >= 64; it does not detect when significant bits + // are shifted into (or past) the sign bit. + // + // We compute the "headroom": the number of redundant + // sign-extension bits minus one (for the sign bit itself). + // A shift is safe iff `shift <= headroom`. + let headroom = if left >= 0 { + left.leading_zeros().saturating_sub(1) + } else { + left.leading_ones().saturating_sub(1) + }; + u32::try_from(right) + .ok() + .filter(|&shift| shift <= headroom) + .and_then(|shift| left.checked_shl(shift))? + } + } + ast::Operator::RShift => match u32::try_from(right) { + Ok(shift) => left >> shift.clamp(0, 63), + Err(_) if right > 0 => { + if left >= 0 { + 0 + } else { + -1 + } + } + Err(_) => return None, + }, + ast::Operator::Pow => { + let exp = u32::try_from(right).ok()?; + left.checked_pow(exp)? + } + ast::Operator::Div | ast::Operator::MatMult => return None, + }; + Some(ConstExpr::Int(value)) + } + ast::Expr::BoolOp(ast::ExprBoolOp { op, values, .. }) => { + let value = match op { + ast::BoolOp::And => { + let mut all_true = true; + for expr in values { + if !resolve_const_expr(expr)?.truthiness() { + all_true = false; + break; + } + } + all_true + } + ast::BoolOp::Or => { + let mut any_true = false; + for expr in values { + if resolve_const_expr(expr)?.truthiness() { + any_true = true; + break; + } + } + any_true + } + }; + Some(ConstExpr::Bool(value)) + } + ast::Expr::Compare(ast::ExprCompare { + left, + ops, + comparators, + .. + }) => { + let mut left_value = resolve_const_expr(left)?; + for (op, comparator) in ops.iter().zip(comparators.iter()) { + let right_value = resolve_const_expr(comparator)?; + let eq = |left: ConstExpr, right: ConstExpr| match (left, right) { + (ConstExpr::Int(left), ConstExpr::Int(right)) => { + Some(left == right) + } + (ConstExpr::None, ConstExpr::None) + | (ConstExpr::Ellipsis, ConstExpr::Ellipsis) => Some(true), + (ConstExpr::None | ConstExpr::Ellipsis, _) + | (_, ConstExpr::None | ConstExpr::Ellipsis) => Some(false), + _ => None, + }; + let result = match op { + ast::CmpOp::Eq => eq(left_value, right_value)?, + ast::CmpOp::NotEq => !eq(left_value, right_value)?, + ast::CmpOp::Lt => left_value.as_int()? < right_value.as_int()?, + ast::CmpOp::LtE => left_value.as_int()? <= right_value.as_int()?, + ast::CmpOp::Gt => left_value.as_int()? > right_value.as_int()?, + ast::CmpOp::GtE => left_value.as_int()? >= right_value.as_int()?, + ast::CmpOp::Is => match (left_value, right_value) { + (ConstExpr::None, ConstExpr::None) + | (ConstExpr::Ellipsis, ConstExpr::Ellipsis) + | (ConstExpr::Bool(true), ConstExpr::Bool(true)) + | (ConstExpr::Bool(false), ConstExpr::Bool(false)) => true, + ( + ConstExpr::None | ConstExpr::Ellipsis | ConstExpr::Bool(_), + _, + ) + | ( + _, + ConstExpr::None | ConstExpr::Ellipsis | ConstExpr::Bool(_), + ) => false, + _ => return None, + }, + ast::CmpOp::IsNot => match (left_value, right_value) { + (ConstExpr::None, ConstExpr::None) + | (ConstExpr::Ellipsis, ConstExpr::Ellipsis) + | (ConstExpr::Bool(true), ConstExpr::Bool(true)) + | (ConstExpr::Bool(false), ConstExpr::Bool(false)) => false, + ( + ConstExpr::None | ConstExpr::Ellipsis | ConstExpr::Bool(_), + _, + ) + | ( + _, + ConstExpr::None | ConstExpr::Ellipsis | ConstExpr::Bool(_), + ) => true, + _ => return None, + }, + ast::CmpOp::In | ast::CmpOp::NotIn => return None, + }; + if !result { + return Some(ConstExpr::Bool(false)); + } + left_value = right_value; + } + Some(ConstExpr::Bool(true)) + } + _ if is_if_type_checking(node) => Some(ConstExpr::Bool(true)), + _ => None, + } } + + Some(resolve_const_expr(node)?.truthiness()) } let expression = self.add_standalone_expression(predicate_node); @@ -1955,14 +2181,14 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { if let Some(msg) = msg { let post_test = self.flow_snapshot(); let negated_predicate = predicate.negated(); - self.record_narrowing_constraint(negated_predicate); - self.record_reachability_constraint(negated_predicate); + let predicate_id = self.record_narrowing_constraint(negated_predicate); + self.record_reachability_constraint_id(predicate_id); self.visit_expr(msg); self.flow_restore(post_test); } - self.record_narrowing_constraint(predicate); - self.record_reachability_constraint(predicate); + let predicate_id = self.record_narrowing_constraint(predicate); + self.record_reachability_constraint_id(predicate_id); } ast::Stmt::Assign(node) => { @@ -2080,7 +2306,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { let (mut last_predicate, mut last_narrowing_id) = self.record_expression_narrowing_constraint(&node.test); let mut last_reachability_constraint = - self.record_reachability_constraint(last_predicate); + self.record_reachability_constraint_id(last_narrowing_id); let is_outer_block_in_type_checking = self.in_type_checking_block; @@ -2131,7 +2357,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { self.record_expression_narrowing_constraint(elif_test); last_reachability_constraint = - self.record_reachability_constraint(last_predicate); + self.record_reachability_constraint_id(last_narrowing_id); } // Determine if this clause is in type checking context @@ -2195,7 +2421,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { // after the loop. let pre_loop = self.flow_snapshot(); let (predicate, predicate_id) = self.record_expression_narrowing_constraint(test); - self.record_reachability_constraint(predicate); + self.record_reachability_constraint_id(predicate_id); let outer_loop = self.push_loop(); self.visit_body(body); @@ -2375,36 +2601,25 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { ); previous_pattern = Some(match_pattern_predicate); let reachability_constraint = - self.record_reachability_constraint(match_predicate); + self.record_reachability_constraint_id(match_narrowing_id); let match_success_guard_failure = case.guard.as_ref().map(|guard| { - let guard_expr = self.add_standalone_expression(guard); - // We could also add the guard expression as a reachability constraint, but - // it seems unlikely that both the case predicate as well as the guard are - // statically known conditions, so we currently don't model that. - self.record_ambiguous_reachability(); self.visit_expr(guard); let post_guard_eval = self.flow_snapshot(); - let predicate = PredicateOrLiteral::Predicate(Predicate { - node: PredicateNode::Expression(guard_expr), - is_positive: true, - }); - // Add the predicate once, then use TDD-level negation for the failure - // path. This ensures the positive and negative atoms share the same ID. - let guard_predicate_id = self.add_predicate(predicate); - let possibly_narrowed = self.compute_possibly_narrowed_places(&predicate); - self.current_use_def_map_mut() - .record_negated_narrowing_constraint_for_places( - guard_predicate_id, - &possibly_narrowed, - ); - let match_success_guard_failure = self.flow_snapshot(); + let (guard_predicate, guard_predicate_id) = + self.record_expression_narrowing_constraint(guard); + let guard_reachability_constraint = + self.record_reachability_constraint_id(guard_predicate_id); + + let guard_success_state = self.flow_snapshot(); self.flow_restore(post_guard_eval); - self.current_use_def_map_mut() - .record_narrowing_constraint_for_places( - guard_predicate_id, - &possibly_narrowed, - ); + self.record_negated_narrowing_constraint( + guard_predicate, + guard_predicate_id, + ); + self.record_negated_reachability_constraint(guard_reachability_constraint); + let match_success_guard_failure = self.flow_snapshot(); + self.flow_restore(guard_success_state); match_success_guard_failure }); @@ -2963,7 +3178,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { self.visit_expr(test); let pre_if = self.flow_snapshot(); let (predicate, predicate_id) = self.record_expression_narrowing_constraint(test); - let reachability_constraint = self.record_reachability_constraint(predicate); + let reachability_constraint = self.record_reachability_constraint_id(predicate_id); self.visit_expr(body); let post_body = self.flow_snapshot(); self.flow_restore(pre_if); diff --git a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs index 3776218e8d4f3..f0c141761c363 100644 --- a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs @@ -208,6 +208,9 @@ use crate::semantic_index::predicate::{ CallableAndCallExpr, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, Predicates, ScopedPredicateId, }; +use crate::semantic_index::use_def::{ + PlaceVersion, PredicatePlaceVersionInfo, PredicatePlaceVersions, +}; use crate::types::{ CallableTypes, IntersectionBuilder, NarrowingConstraint, Truthiness, Type, TypeContext, UnionBuilder, UnionType, infer_expression_type, infer_narrowing_constraint, @@ -783,28 +786,64 @@ impl ReachabilityConstraints { /// - `ALWAYS_FALSE`: this path is impossible → Never /// /// The final result is the union of all path results. + #[expect(clippy::too_many_arguments)] pub(crate) fn narrow_by_constraint<'db>( &self, db: &'db dyn Db, predicates: &Predicates<'db>, + predicate_place_versions: &PredicatePlaceVersions, id: ScopedReachabilityConstraintId, base_ty: Type<'db>, place: ScopedPlaceId, + binding_place_version: Option, ) -> Type<'db> { - self.narrow_by_constraint_inner(db, predicates, id, base_ty, place, None) + let mut memo = FxHashMap::default(); + let mut truthiness_memo = FxHashMap::default(); + let redundant_union = self.narrow_by_constraint_inner( + db, + predicates, + predicate_place_versions, + id, + base_ty, + place, + binding_place_version, + None, + &mut memo, + &mut truthiness_memo, + ); + UnionBuilder::new(db) + .unpack_aliases(false) + .add(redundant_union) + .build() } /// Inner recursive helper that accumulates narrowing constraints along each TDD path. + #[allow(clippy::too_many_arguments)] fn narrow_by_constraint_inner<'db>( &self, db: &'db dyn Db, predicates: &Predicates<'db>, + predicate_place_versions: &PredicatePlaceVersions, id: ScopedReachabilityConstraintId, base_ty: Type<'db>, place: ScopedPlaceId, + binding_place_version: Option, accumulated: Option>, + memo: &mut FxHashMap< + ( + ScopedReachabilityConstraintId, + Option>, + ), + Type<'db>, + >, + truthiness_memo: &mut FxHashMap, Truthiness>, ) -> Type<'db> { - match id { + let key = (id, accumulated.clone()); + if let Some(cached) = memo.get(&key).copied() { + return cached; + } + + let narrowed = match id { ALWAYS_TRUE | AMBIGUOUS => { // Apply all accumulated narrowing constraints to the base type match accumulated { @@ -818,101 +857,97 @@ impl ReachabilityConstraints { _ => { let node = self.get_interior_node(id); let predicate = predicates[node.atom]; - - // `ReturnsNever` predicates don't narrow any variable; they only - // affect reachability. Evaluate the predicate to determine which - // path(s) are reachable, rather than walking both branches. - // `ReturnsNever` always evaluates to `AlwaysTrue` or `AlwaysFalse`, - // never `Ambiguous`. - if matches!(predicate.node, PredicateNode::ReturnsNever(_)) { - return match Self::analyze_single(db, &predicate) { - Truthiness::AlwaysTrue => self.narrow_by_constraint_inner( - db, - predicates, - node.if_true, - base_ty, - place, - accumulated, - ), - Truthiness::AlwaysFalse => self.narrow_by_constraint_inner( + macro_rules! narrow { + ($next_id:expr, $next_accumulated:expr) => { + self.narrow_by_constraint_inner( db, predicates, - node.if_false, + predicate_place_versions, + $next_id, base_ty, place, - accumulated, - ), - Truthiness::Ambiguous => { - unreachable!("ReturnsNever predicates should never be Ambiguous") - } + binding_place_version, + $next_accumulated, + memo, + truthiness_memo, + ) }; } // Check if this predicate narrows the variable we're interested in. - let pos_constraint = infer_narrowing_constraint(db, predicate, place); + let neg_predicate = Predicate { + node: predicate.node, + is_positive: !predicate.is_positive, + }; + let place_version_info = predicate_place_versions.get(&(node.atom, place)); + let can_apply_narrowing = place_version_info.is_some() + && Self::predicate_applies_to_place_version( + place_version_info, + binding_place_version, + ); + let (pos_constraint, neg_constraint) = if can_apply_narrowing { + ( + infer_narrowing_constraint(db, predicate, place), + infer_narrowing_constraint(db, neg_predicate, place), + ) + } else { + // No recorded place-version metadata means this predicate cannot narrow + // this place, or the narrowing belongs to a different place version. + // In either case, skip the expensive narrowing-inference queries. + (None, None) + }; + + // If this predicate does not narrow the current place and we can statically + // determine its truthiness, follow only the reachable branch. + if pos_constraint.is_none() && neg_constraint.is_none() { + match Self::analyze_single_cached(db, predicate, truthiness_memo) { + Truthiness::AlwaysTrue => { + return narrow!(node.if_true, accumulated); + } + Truthiness::AlwaysFalse => { + return narrow!(node.if_false, accumulated); + } + Truthiness::Ambiguous => {} + } + } // If the true branch is statically unreachable, skip it entirely. if node.if_true == ALWAYS_FALSE { - let neg_predicate = Predicate { - node: predicate.node, - is_positive: !predicate.is_positive, - }; - let neg_constraint = infer_narrowing_constraint(db, neg_predicate, place); let false_accumulated = accumulate_constraint(db, accumulated, neg_constraint); - return self.narrow_by_constraint_inner( - db, - predicates, - node.if_false, - base_ty, - place, - false_accumulated, - ); + return narrow!(node.if_false, false_accumulated); } // If the false branch is statically unreachable, skip it entirely. if node.if_false == ALWAYS_FALSE { let true_accumulated = accumulate_constraint(db, accumulated, pos_constraint); - return self.narrow_by_constraint_inner( - db, - predicates, - node.if_true, - base_ty, - place, - true_accumulated, - ); + return narrow!(node.if_true, true_accumulated); } // True branch: predicate holds → accumulate positive narrowing let true_accumulated = accumulate_constraint(db, accumulated.clone(), pos_constraint); - let true_ty = self.narrow_by_constraint_inner( - db, - predicates, - node.if_true, - base_ty, - place, - true_accumulated, - ); + let true_ty = narrow!(node.if_true, true_accumulated); // False branch: predicate doesn't hold → accumulate negative narrowing - let neg_predicate = Predicate { - node: predicate.node, - is_positive: !predicate.is_positive, - }; - let neg_constraint = infer_narrowing_constraint(db, neg_predicate, place); let false_accumulated = accumulate_constraint(db, accumulated, neg_constraint); - let false_ty = self.narrow_by_constraint_inner( - db, - predicates, - node.if_false, - base_ty, - place, - false_accumulated, - ); - - UnionType::from_elements(db, [true_ty, false_ty]) + let false_ty = narrow!(node.if_false, false_accumulated); + + // We won't do a union type redundancy check here, as it only needs to be performed once for the final result. + UnionType::from_elements_no_redundancy_check(db, [true_ty, false_ty]) } - } + }; + + memo.insert(key, narrowed); + narrowed + } + + fn predicate_applies_to_place_version( + place_version_info: Option<&PredicatePlaceVersionInfo>, + binding_place_version: Option, + ) -> bool { + binding_place_version.is_none_or(|binding_place_version| { + place_version_info.is_some_and(|info| info.versions.contains(&binding_place_version)) + }) } /// Analyze the statically known reachability for a given constraint. @@ -1149,4 +1184,18 @@ impl ReachabilityConstraints { } } } + + fn analyze_single_cached<'db>( + db: &'db dyn Db, + predicate: Predicate<'db>, + memo: &mut FxHashMap, Truthiness>, + ) -> Truthiness { + if let Some(cached) = memo.get(&predicate) { + return *cached; + } + + let analyzed = Self::analyze_single(db, &predicate); + memo.insert(predicate, analyzed); + analyzed + } } diff --git a/crates/ty_python_semantic/src/semantic_index/use_def.rs b/crates/ty_python_semantic/src/semantic_index/use_def.rs index c0597e87a8e51..36ecebfbe3de7 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def.rs @@ -242,6 +242,7 @@ use ruff_index::{IndexVec, newtype_index}; use rustc_hash::FxHashMap; +use smallvec::SmallVec; use crate::node_key::NodeKey; use crate::place::BoundnessAnalysis; @@ -268,7 +269,15 @@ use crate::types::{PossiblyNarrowedPlaces, Truthiness, Type}; mod place_state; pub(super) use place_state::PreviousDefinitions; -pub(crate) use place_state::{LiveBinding, ScopedDefinitionId}; +pub(crate) use place_state::{LiveBinding, PlaceVersion, ScopedDefinitionId}; + +#[derive(Clone, Debug, Default, PartialEq, Eq, salsa::Update, get_size2::GetSize)] +pub(crate) struct PredicatePlaceVersionInfo { + pub(crate) versions: SmallVec<[PlaceVersion; 2]>, +} + +pub(crate) type PredicatePlaceVersions = + FxHashMap<(ScopedPredicateId, ScopedPlaceId), PredicatePlaceVersionInfo>; /// Applicable definitions and constraints for every use of a name. #[derive(Debug, PartialEq, Eq, salsa::Update, get_size2::GetSize)] @@ -280,6 +289,15 @@ pub(crate) struct UseDefMap<'db> { /// Array of predicates in this scope. predicates: Predicates<'db>, + /// Place version associated with each definition ID. + /// + /// This stores the version once per definition instead of duplicating it in every `LiveBinding` + /// clone across `bindings_by_use` / snapshots. + definition_place_versions: IndexVec, + + /// Place versions to which a given predicate occurrence can apply for narrowing. + predicate_place_versions: PredicatePlaceVersions, + /// Array of reachability constraints in this scope. reachability_constraints: ReachabilityConstraints, @@ -373,7 +391,9 @@ impl<'db> UseDefMap<'db> { ApplicableConstraints::UnboundBinding(NarrowingEvaluator { constraint, predicates: &self.predicates, + predicate_place_versions: &self.predicate_place_versions, reachability_constraints: &self.reachability_constraints, + binding_place_version: None, }) } ConstraintKey::NestedScope(nested_scope) => { @@ -413,7 +433,9 @@ impl<'db> UseDefMap<'db> { NarrowingEvaluator { constraint, predicates: &self.predicates, + predicate_place_versions: &self.predicate_place_versions, reachability_constraints: &self.reachability_constraints, + binding_place_version: None, } } @@ -654,7 +676,9 @@ impl<'db> UseDefMap<'db> { ) -> BindingWithConstraintsIterator<'map, 'db> { BindingWithConstraintsIterator { all_definitions: &self.all_definitions, + definition_place_versions: &self.definition_place_versions, predicates: &self.predicates, + predicate_place_versions: &self.predicate_place_versions, reachability_constraints: &self.reachability_constraints, boundness_analysis, inner: bindings.iter(), @@ -709,7 +733,9 @@ type EnclosingSnapshots = IndexVec #[derive(Clone, Debug)] pub(crate) struct BindingWithConstraintsIterator<'map, 'db> { pub(crate) all_definitions: &'map IndexVec>, + definition_place_versions: &'map IndexVec, pub(crate) predicates: &'map Predicates<'db>, + pub(crate) predicate_place_versions: &'map PredicatePlaceVersions, pub(crate) reachability_constraints: &'map ReachabilityConstraints, pub(crate) boundness_analysis: BoundnessAnalysis, inner: LiveBindingsIterator<'map>, @@ -720,6 +746,7 @@ impl<'map, 'db> Iterator for BindingWithConstraintsIterator<'map, 'db> { fn next(&mut self) -> Option { let predicates = self.predicates; + let predicate_place_versions = self.predicate_place_versions; let reachability_constraints = self.reachability_constraints; self.inner @@ -729,7 +756,11 @@ impl<'map, 'db> Iterator for BindingWithConstraintsIterator<'map, 'db> { narrowing_constraint: NarrowingEvaluator { constraint: live_binding.narrowing_constraint, predicates, + predicate_place_versions, reachability_constraints, + binding_place_version: Some( + self.definition_place_versions[live_binding.binding], + ), }, reachability_constraint: live_binding.reachability_constraint, }) @@ -747,7 +778,9 @@ pub(crate) struct BindingWithConstraints<'map, 'db> { pub(crate) struct NarrowingEvaluator<'map, 'db> { pub(crate) constraint: ScopedNarrowingConstraint, predicates: &'map Predicates<'db>, + predicate_place_versions: &'map PredicatePlaceVersions, reachability_constraints: &'map ReachabilityConstraints, + binding_place_version: Option, } impl<'db> NarrowingEvaluator<'_, 'db> { @@ -760,9 +793,11 @@ impl<'db> NarrowingEvaluator<'_, 'db> { self.reachability_constraints.narrow_by_constraint( db, self.predicates, + self.predicate_place_versions, self.constraint, base_ty, place, + self.binding_place_version, ) } } @@ -828,9 +863,15 @@ pub(super) struct UseDefMapBuilder<'db> { /// Append-only array of [`DefinitionState`]. all_definitions: IndexVec>, + /// Place version associated with each definition ID. + definition_place_versions: IndexVec, + /// Builder of predicates. pub(super) predicates: PredicatesBuilder<'db>, + /// Place versions to which a given predicate occurrence can apply for narrowing. + predicate_place_versions: PredicatePlaceVersions, + /// Builder of reachability constraints. pub(super) reachability_constraints: ReachabilityConstraintsBuilder, @@ -872,7 +913,9 @@ impl<'db> UseDefMapBuilder<'db> { pub(super) fn new(is_class_scope: bool) -> Self { Self { all_definitions: IndexVec::from_iter([DefinitionState::Undefined]), + definition_place_versions: IndexVec::from_iter([PlaceVersion::default()]), predicates: PredicatesBuilder::default(), + predicate_place_versions: PredicatePlaceVersions::default(), reachability_constraints: ReachabilityConstraintsBuilder::default(), bindings_by_use: IndexVec::new(), reachability: ScopedReachabilityConstraintId::ALWAYS_TRUE, @@ -959,13 +1002,15 @@ impl<'db> UseDefMapBuilder<'db> { self.declarations_by_binding .insert(binding, place_state.declarations().clone()); - place_state.record_binding( + let place_version = place_state.record_binding( def_id, self.reachability, self.is_class_scope, place.is_symbol(), previous_definitions, ); + let version_id = self.definition_place_versions.push(place_version); + debug_assert_eq!(def_id, version_id); let bindings = match place { ScopedPlaceId::Symbol(symbol) => { @@ -1009,6 +1054,8 @@ impl<'db> UseDefMapBuilder<'db> { return; } + self.record_predicate_place_versions(predicate, places); + let atom = self.reachability_constraints.add_atom(predicate); self.record_narrowing_constraint_node_for_places(atom, places); } @@ -1030,11 +1077,46 @@ impl<'db> UseDefMapBuilder<'db> { return; } + self.record_predicate_place_versions(predicate, places); + let atom = self.reachability_constraints.add_atom(predicate); let negated = self.reachability_constraints.add_not_constraint(atom); self.record_narrowing_constraint_node_for_places(negated, places); } + fn record_predicate_place_versions( + &mut self, + predicate: ScopedPredicateId, + places: &PossiblyNarrowedPlaces, + ) { + for place in places { + let bindings = match place { + ScopedPlaceId::Symbol(symbol_id) => { + self.symbol_states.get(*symbol_id).map(PlaceState::bindings) + } + ScopedPlaceId::Member(member_id) => { + self.member_states.get(*member_id).map(PlaceState::bindings) + } + }; + let Some(bindings) = bindings else { + continue; + }; + + let versions = bindings + .iter() + .map(|binding| self.definition_place_versions[binding.binding]); + let entry = self + .predicate_place_versions + .entry((predicate, *place)) + .or_default(); + for version in versions { + if !entry.versions.contains(&version) { + entry.versions.push(version); + } + } + } + } + /// Records a TDD narrowing constraint node for the specified places. fn record_narrowing_constraint_node_for_places( &mut self, @@ -1202,6 +1284,8 @@ impl<'db> UseDefMapBuilder<'db> { let def_id = self .all_definitions .push(DefinitionState::Defined(declaration)); + let version_id = self.definition_place_versions.push(PlaceVersion::default()); + debug_assert_eq!(def_id, version_id); let place_state = match place { ScopedPlaceId::Symbol(symbol) => &mut self.symbol_states[symbol], @@ -1239,13 +1323,15 @@ impl<'db> UseDefMapBuilder<'db> { ScopedPlaceId::Member(member) => &mut self.member_states[member], }; place_state.record_declaration(def_id, self.reachability); - place_state.record_binding( + let place_version = place_state.record_binding( def_id, self.reachability, self.is_class_scope, place.is_symbol(), PreviousDefinitions::AreShadowed, ); + let version_id = self.definition_place_versions.push(place_version); + debug_assert_eq!(def_id, version_id); let reachable_definitions = match place { ScopedPlaceId::Symbol(symbol) => &mut self.reachable_symbol_definitions[symbol], @@ -1272,14 +1358,15 @@ impl<'db> UseDefMapBuilder<'db> { ScopedPlaceId::Symbol(symbol) => &mut self.symbol_states[symbol], ScopedPlaceId::Member(member) => &mut self.member_states[member], }; - - place_state.record_binding( + let place_version = place_state.record_binding( def_id, self.reachability, self.is_class_scope, place.is_symbol(), PreviousDefinitions::AreShadowed, ); + let version_id = self.definition_place_versions.push(place_version); + debug_assert_eq!(def_id, version_id); } pub(super) fn record_use( @@ -1504,11 +1591,13 @@ impl<'db> UseDefMapBuilder<'db> { self.mark_reachability_constraints(); self.all_definitions.shrink_to_fit(); + self.definition_place_versions.shrink_to_fit(); self.symbol_states.shrink_to_fit(); self.member_states.shrink_to_fit(); self.reachable_symbol_definitions.shrink_to_fit(); self.reachable_member_definitions.shrink_to_fit(); self.bindings_by_use.shrink_to_fit(); + self.predicate_place_versions.shrink_to_fit(); self.node_reachability.shrink_to_fit(); self.declarations_by_binding.shrink_to_fit(); self.bindings_by_definition.shrink_to_fit(); @@ -1517,6 +1606,8 @@ impl<'db> UseDefMapBuilder<'db> { UseDefMap { all_definitions: self.all_definitions, predicates: self.predicates.build(), + definition_place_versions: self.definition_place_versions, + predicate_place_versions: self.predicate_place_versions, reachability_constraints: self.reachability_constraints.build(), bindings_by_use: self.bindings_by_use, node_reachability: self.node_reachability, diff --git a/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs b/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs index 71833f3406397..a3ff20a3c15d0 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs @@ -69,6 +69,29 @@ impl ScopedDefinitionId { } } +/// A monotonically increasing place generation. +/// +/// The generation increments whenever bindings for a place are shadowed by reassignment. +#[newtype_index] +#[derive(Ord, PartialOrd, salsa::Update, get_size2::GetSize)] +pub(crate) struct PlaceVersion; + +impl Default for PlaceVersion { + fn default() -> Self { + PlaceVersion::from_u32(0) + } +} + +impl PlaceVersion { + pub(crate) fn next(self) -> PlaceVersion { + let next = self + .as_u32() + .checked_add(1) + .expect("PlaceVersion overflowed"); + PlaceVersion::from_u32(next) + } +} + /// Live declarations for a single place at some point in control flow, with their /// corresponding reachability constraints. #[derive(Clone, Debug, Default, PartialEq, Eq, salsa::Update, get_size2::GetSize)] @@ -213,7 +236,10 @@ pub(super) struct Bindings { /// "unbound" binding. unbound_narrowing_constraint: Option, /// A list of live bindings for this place, sorted by their `ScopedDefinitionId` + #[allow(clippy::struct_field_names)] live_bindings: SmallVec<[LiveBinding; 2]>, + /// Latest place version seen for this place. + latest_place_version: PlaceVersion, } impl Bindings { @@ -251,6 +277,7 @@ impl Bindings { Self { unbound_narrowing_constraint: None, live_bindings: smallvec![initial_binding], + latest_place_version: PlaceVersion::default(), } } @@ -262,7 +289,7 @@ impl Bindings { is_class_scope: bool, is_place_name: bool, previous_definitions: PreviousDefinitions, - ) { + ) -> PlaceVersion { // If we are in a class scope, and the unbound name binding was previously visible, but we will // now replace it, record the narrowing constraints on it: if is_class_scope && is_place_name && self.live_bindings[0].binding.is_unbound() { @@ -272,12 +299,14 @@ impl Bindings { // constraints. if previous_definitions.are_shadowed() { self.live_bindings.clear(); + self.latest_place_version = self.latest_place_version.next(); } self.live_bindings.push(LiveBinding { binding, narrowing_constraint: ScopedNarrowingConstraint::ALWAYS_TRUE, reachability_constraint, }); + self.latest_place_version } /// Add given constraint to all live bindings. @@ -315,6 +344,7 @@ impl Bindings { reachability_constraints: &mut ReachabilityConstraintsBuilder, ) { let a = std::mem::take(self); + self.latest_place_version = a.latest_place_version.max(b.latest_place_version); if let Some((a, b)) = a .unbound_narrowing_constraint @@ -334,15 +364,29 @@ impl Bindings { for zipped in a.merge_join_by(b, |a, b| a.binding.cmp(&b.binding)) { match zipped { EitherOrBoth::Both(a, b) => { - // If the same definition is visible through both paths, we OR the narrowing - // constraints: the type should be narrowed by whichever path was taken. - let narrowing_constraint = reachability_constraints - .add_or_constraint(a.narrowing_constraint, b.narrowing_constraint); - // For reachability constraints, we also merge using a ternary OR operation: let reachability_constraint = reachability_constraints .add_or_constraint(a.reachability_constraint, b.reachability_constraint); + let narrowing_constraint = if a.narrowing_constraint + == ScopedNarrowingConstraint::ALWAYS_TRUE + && b.narrowing_constraint == ScopedNarrowingConstraint::ALWAYS_TRUE + { + // short-circuit: if both sides are ALWAYS_TRUE, the result is ALWAYS_TRUE without needing to create a new TDD node. + ScopedNarrowingConstraint::ALWAYS_TRUE + } else { + // A branch contributes narrowing only when it is reachable. + // Without this gating, `OR(a_narrowing, b_narrowing)` allows an unreachable + // branch with `ALWAYS_TRUE` narrowing to cancel useful narrowing from the + // reachable branch. + let a_narrowing_gated = reachability_constraints + .add_and_constraint(a.narrowing_constraint, a.reachability_constraint); + let b_narrowing_gated = reachability_constraints + .add_and_constraint(b.narrowing_constraint, b.reachability_constraint); + reachability_constraints + .add_or_constraint(a_narrowing_gated, b_narrowing_gated) + }; + self.live_bindings.push(LiveBinding { binding: a.binding, narrowing_constraint, @@ -381,7 +425,7 @@ impl PlaceState { is_class_scope: bool, is_place_name: bool, previous_definitions: PreviousDefinitions, - ) { + ) -> PlaceVersion { debug_assert_ne!(binding_id, ScopedDefinitionId::UNBOUND); self.bindings.record_binding( binding_id, @@ -389,7 +433,7 @@ impl PlaceState { is_class_scope, is_place_name, previous_definitions, - ); + ) } /// Add given constraint to all live bindings. @@ -636,6 +680,31 @@ mod tests { assert_eq!(bindings[1].1, atom0); assert_eq!(bindings[2].0, 3); assert_eq!(bindings[2].1, atom3); + + // An unreachable branch should not dilute narrowing from the reachable branch. + let mut sym4a = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); + sym4a.record_binding( + ScopedDefinitionId::from_u32(4), + ScopedReachabilityConstraintId::ALWAYS_FALSE, + false, + true, + PreviousDefinitions::AreShadowed, + ); + + let mut sym4b = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); + sym4b.record_binding( + ScopedDefinitionId::from_u32(4), + ScopedReachabilityConstraintId::ALWAYS_TRUE, + false, + true, + PreviousDefinitions::AreShadowed, + ); + let atom4 = reachability_constraints.add_atom(ScopedPredicateId::new(4)); + sym4b.record_narrowing_constraint(&mut reachability_constraints, atom4); + + sym4a.merge(sym4b, &mut reachability_constraints); + let merged_constraint = sym4a.bindings().iter().next().unwrap().narrowing_constraint; + assert_eq!(merged_constraint, atom4); } #[test] diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 043f23ca36ba3..c31f9e5a4c819 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -12257,6 +12257,20 @@ impl<'db> UnionType<'db> { .build() } + pub(crate) fn from_elements_no_redundancy_check(db: &'db dyn Db, elements: I) -> Type<'db> + where + I: IntoIterator, + T: Into>, + { + elements + .into_iter() + .fold( + UnionBuilder::new(db).check_redundancy(false), + |builder, element| builder.add(element.into()), + ) + .build() + } + /// Create a union from a list of elements without unpacking type aliases. pub(crate) fn from_elements_leave_aliases(db: &'db dyn Db, elements: I) -> Type<'db> where diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 1ba0388b32e6e..950b17d5515e0 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -242,11 +242,13 @@ const MAX_NON_RECURSIVE_UNION_LITERALS: usize = 256; /// if reachability analysis etc. fails when analysing these enums. const MAX_NON_RECURSIVE_UNION_ENUM_LITERALS: usize = 8192; +#[allow(clippy::struct_excessive_bools)] pub(crate) struct UnionBuilder<'db> { elements: Vec>, db: &'db dyn Db, unpack_aliases: bool, order_elements: bool, + check_redundancy: bool, /// This is enabled when joining types in a `cycle_recovery` function. /// Since a cycle cannot be created within a `cycle_recovery` function, /// execution of `is_redundant_with` is skipped. @@ -261,6 +263,7 @@ impl<'db> UnionBuilder<'db> { elements: vec![], unpack_aliases: true, order_elements: false, + check_redundancy: true, cycle_recovery: false, recursively_defined: RecursivelyDefined::No, } @@ -276,9 +279,15 @@ impl<'db> UnionBuilder<'db> { self } + pub(crate) fn check_redundancy(mut self, val: bool) -> Self { + self.check_redundancy = val; + self + } + pub(crate) fn cycle_recovery(mut self, val: bool) -> Self { self.cycle_recovery = val; if self.cycle_recovery { + self.check_redundancy = false; self.unpack_aliases = false; } self @@ -622,7 +631,7 @@ impl<'db> UnionBuilder<'db> { // If an alias gets here, it means we aren't unpacking aliases, and we also // shouldn't try to simplify aliases out of the union, because that will require // unpacking them. - let should_simplify_full = !matches!(ty, Type::TypeAlias(_)) && !self.cycle_recovery; + let should_simplify_full = !matches!(ty, Type::TypeAlias(_)) && self.check_redundancy; let mut ty_negated: Option = None; let mut to_remove = SmallVec::<[usize; 2]>::new(); diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 4d62abcbf0bbd..b7e1fbd8068a1 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -62,70 +62,133 @@ pub(crate) fn infer_narrowing_constraint<'db>( ) -> Option> { let constraints = match predicate.node { PredicateNode::Expression(expression) => { - if predicate.is_positive { - all_narrowing_constraints_for_expression(db, expression) - } else { - all_negative_narrowing_constraints_for_expression(db, expression) - } - } - PredicateNode::Pattern(pattern) => { - if predicate.is_positive { - all_narrowing_constraints_for_pattern(db, pattern) - } else { - all_negative_narrowing_constraints_for_pattern(db, pattern) - } + all_narrowing_constraints_for_expression(db, expression) } + PredicateNode::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern), PredicateNode::ReturnsNever(_) => return None, PredicateNode::StarImportPlaceholder(_) => return None, }; - constraints.and_then(|constraints| constraints.get(&place).cloned()) + constraints.and_then(|constraints| constraints.get(place, predicate.is_positive).cloned()) } -#[salsa::tracked(returns(as_ref), heap_size=ruff_memory_usage::heap_size)] -fn all_narrowing_constraints_for_pattern<'db>( - db: &'db dyn Db, - pattern: PatternPredicate<'db>, -) -> Option> { - let module = parsed_module(db, pattern.file(db)).load(db); - NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern), true).finish() +#[derive(Default, PartialEq, Debug, Eq, Clone, salsa::Update, get_size2::GetSize)] +struct PerPlaceDualNarrowingConstraint<'db> { + positive: Option>, + negative: Option>, } -#[salsa::tracked( - returns(as_ref), - cycle_initial=|_, _, _| None, - heap_size=ruff_memory_usage::heap_size, -)] -fn all_narrowing_constraints_for_expression<'db>( - db: &'db dyn Db, - expression: Expression<'db>, -) -> Option> { - let module = parsed_module(db, expression.file(db)).load(db); - NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression), true) - .finish() +type DualNarrowingConstraintsMap<'db> = + FxHashMap>; + +#[derive(Default, PartialEq, Debug, Eq, Clone, salsa::Update, get_size2::GetSize)] +struct DualNarrowingConstraints<'db> { + by_place: DualNarrowingConstraintsMap<'db>, + has_positive: bool, + has_negative: bool, +} + +impl<'db> DualNarrowingConstraints<'db> { + fn from_sides( + positive: Option>, + negative: Option>, + ) -> Self { + let mut by_place = DualNarrowingConstraintsMap::default(); + let has_positive = positive.is_some(); + let has_negative = negative.is_some(); + + if let Some(positive) = positive { + for (place, constraint) in positive { + by_place.entry(place).or_default().positive = Some(constraint); + } + } + + if let Some(negative) = negative { + for (place, constraint) in negative { + by_place.entry(place).or_default().negative = Some(constraint); + } + } + + Self { + by_place, + has_positive, + has_negative, + } + } + + fn into_sides( + self, + ) -> ( + Option>, + Option>, + ) { + let mut positive = self.has_positive.then(FxHashMap::default); + let mut negative = self.has_negative.then(FxHashMap::default); + + for (place, constraints) in self.by_place { + if let (Some(positive), Some(constraint)) = (&mut positive, constraints.positive) { + positive.insert(place, constraint); + } + if let (Some(negative), Some(constraint)) = (&mut negative, constraints.negative) { + negative.insert(place, constraint); + } + } + + (positive, negative) + } + + fn get(&self, place: ScopedPlaceId, is_positive: bool) -> Option<&NarrowingConstraint<'db>> { + if is_positive && !self.has_positive || !is_positive && !self.has_negative { + return None; + } + + self.by_place.get(&place).and_then(|constraints| { + if is_positive { + constraints.positive.as_ref() + } else { + constraints.negative.as_ref() + } + }) + } + + fn shrink_to_fit(&mut self) { + self.by_place.shrink_to_fit(); + } + + fn swap_polarity(mut self) -> Self { + std::mem::swap(&mut self.has_positive, &mut self.has_negative); + for constraints in self.by_place.values_mut() { + std::mem::swap(&mut constraints.positive, &mut constraints.negative); + } + self + } } +#[allow(clippy::unnecessary_wraps)] #[salsa::tracked( returns(as_ref), cycle_initial=|_, _, _| None, heap_size=ruff_memory_usage::heap_size, )] -fn all_negative_narrowing_constraints_for_expression<'db>( +fn all_narrowing_constraints_for_expression<'db>( db: &'db dyn Db, expression: Expression<'db>, -) -> Option> { +) -> Option> { let module = parsed_module(db, expression.file(db)).load(db); - NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression), false) - .finish() + Some( + NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression)) + .finish(), + ) } +#[allow(clippy::unnecessary_wraps)] #[salsa::tracked(returns(as_ref), heap_size=ruff_memory_usage::heap_size)] -fn all_negative_narrowing_constraints_for_pattern<'db>( +fn all_narrowing_constraints_for_pattern<'db>( db: &'db dyn Db, pattern: PatternPredicate<'db>, -) -> Option> { +) -> Option> { let module = parsed_module(db, pattern.file(db)).load(db); - NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern), false).finish() + Some(NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern)).finish()) } /// Functions that can be used to narrow the type of a first argument using a "classinfo" second argument. @@ -495,74 +558,94 @@ struct NarrowingConstraintsBuilder<'db, 'ast> { db: &'db dyn Db, module: &'ast ParsedModuleRef, predicate: PredicateNode<'db>, - is_positive: bool, } impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { - fn new( - db: &'db dyn Db, - module: &'ast ParsedModuleRef, - predicate: PredicateNode<'db>, - is_positive: bool, - ) -> Self { + fn new(db: &'db dyn Db, module: &'ast ParsedModuleRef, predicate: PredicateNode<'db>) -> Self { Self { db, module, predicate, - is_positive, } } - fn finish(mut self) -> Option> { - let mut constraints: Option> = match self.predicate { - PredicateNode::Expression(expression) => { - self.evaluate_expression_predicate(expression, self.is_positive) - } - PredicateNode::Pattern(pattern) => { - self.evaluate_pattern_predicate(pattern, self.is_positive) + fn finish(mut self) -> DualNarrowingConstraints<'db> { + let mut constraints = match self.predicate { + PredicateNode::Expression(expression) => self.evaluate_expression_predicate(expression), + PredicateNode::Pattern(pattern) => self.evaluate_pattern_predicate(pattern), + PredicateNode::ReturnsNever(_) | PredicateNode::StarImportPlaceholder(_) => { + return DualNarrowingConstraints::default(); } - PredicateNode::ReturnsNever(_) => return None, - PredicateNode::StarImportPlaceholder(_) => return None, }; - if let Some(ref mut constraints) = constraints { - constraints.shrink_to_fit(); - } + constraints.shrink_to_fit(); constraints } + fn merge_constraints_and_sequence( + &self, + sub_constraints: Vec>>, + ) -> Option> { + let mut aggregation: Option> = None; + for sub_constraint in sub_constraints.into_iter().flatten() { + if let Some(ref mut some_aggregation) = aggregation { + merge_constraints_and(some_aggregation, sub_constraint, self.db); + } else { + aggregation = Some(sub_constraint); + } + } + aggregation + } + + fn merge_constraints_or_sequence( + &self, + sub_constraints: Vec>>, + ) -> Option> { + let (mut first, rest) = { + let mut it = sub_constraints.into_iter(); + (it.next()?, it) + }; + + if let Some(ref mut first) = first { + for rest_constraint in rest { + if let Some(rest_constraint) = rest_constraint { + merge_constraints_or(first, rest_constraint, self.db); + } else { + return None; + } + } + } + first + } + fn evaluate_expression_predicate( &mut self, expression: Expression<'db>, - is_positive: bool, - ) -> Option> { + ) -> DualNarrowingConstraints<'db> { let expression_node = expression.node_ref(self.db, self.module); - self.evaluate_expression_node_predicate(expression_node, expression, is_positive) + self.evaluate_expression_node_predicate(expression_node, expression) } fn evaluate_expression_node_predicate( &mut self, expression_node: &ruff_python_ast::Expr, expression: Expression<'db>, - is_positive: bool, - ) -> Option> { + ) -> DualNarrowingConstraints<'db> { match expression_node { ast::Expr::Name(_) | ast::Expr::Attribute(_) | ast::Expr::Subscript(_) => { - self.evaluate_simple_expr(expression_node, is_positive) + self.evaluate_simple_expr(expression_node) } ast::Expr::Compare(expr_compare) => { - self.evaluate_expr_compare(expr_compare, expression, is_positive) - } - ast::Expr::Call(expr_call) => { - self.evaluate_expr_call(expr_call, expression, is_positive) - } - ast::Expr::UnaryOp(unary_op) if unary_op.op == ast::UnaryOp::Not => { - self.evaluate_expression_node_predicate(&unary_op.operand, expression, !is_positive) + self.evaluate_expr_compare(expr_compare, expression) } - ast::Expr::BoolOp(bool_op) => self.evaluate_bool_op(bool_op, expression, is_positive), - ast::Expr::Named(expr_named) => self.evaluate_expr_named(expr_named, is_positive), - _ => None, + ast::Expr::Call(expr_call) => self.evaluate_expr_call(expr_call, expression), + ast::Expr::UnaryOp(unary_op) if unary_op.op == ast::UnaryOp::Not => self + .evaluate_expression_node_predicate(&unary_op.operand, expression) + .swap_polarity(), + ast::Expr::BoolOp(bool_op) => self.evaluate_bool_op(bool_op, expression), + ast::Expr::Named(expr_named) => self.evaluate_expr_named(expr_named), + _ => DualNarrowingConstraints::default(), } } @@ -570,38 +653,32 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, pattern_predicate_kind: &PatternPredicateKind<'db>, subject: Expression<'db>, - is_positive: bool, - ) -> Option> { + ) -> DualNarrowingConstraints<'db> { match pattern_predicate_kind { PatternPredicateKind::Singleton(singleton) => { - self.evaluate_match_pattern_singleton(subject, *singleton, is_positive) + self.evaluate_match_pattern_singleton(subject, *singleton) } PatternPredicateKind::Class(cls, kind) => { - self.evaluate_match_pattern_class(subject, *cls, *kind, is_positive) - } - PatternPredicateKind::Value(expr) => { - self.evaluate_match_pattern_value(subject, *expr, is_positive) + self.evaluate_match_pattern_class(subject, *cls, *kind) } + PatternPredicateKind::Value(expr) => self.evaluate_match_pattern_value(subject, *expr), PatternPredicateKind::Or(predicates) => { - self.evaluate_match_pattern_or(subject, predicates, is_positive) + self.evaluate_match_pattern_or(subject, predicates) } PatternPredicateKind::As(pattern, _) => pattern .as_deref() - .and_then(|p| self.evaluate_pattern_predicate_kind(p, subject, is_positive)), - PatternPredicateKind::Unsupported => None, + .map_or_else(DualNarrowingConstraints::default, |p| { + self.evaluate_pattern_predicate_kind(p, subject) + }), + PatternPredicateKind::Unsupported => DualNarrowingConstraints::default(), } } fn evaluate_pattern_predicate( &mut self, pattern: PatternPredicate<'db>, - is_positive: bool, - ) -> Option> { - self.evaluate_pattern_predicate_kind( - pattern.kind(self.db), - pattern.subject(self.db), - is_positive, - ) + ) -> DualNarrowingConstraints<'db> { + self.evaluate_pattern_predicate_kind(pattern.kind(self.db), pattern.subject(self.db)) } fn places(&self) -> &'db PlaceTable { @@ -713,32 +790,29 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { } } - fn evaluate_simple_expr( - &mut self, - expr: &ast::Expr, - is_positive: bool, - ) -> Option> { - let target = PlaceExpr::try_from_expr(expr)?; - let place = self.expect_place(&target); - - let ty = if is_positive { - Type::AlwaysFalsy.negate(self.db) - } else { - Type::AlwaysTruthy.negate(self.db) + fn evaluate_simple_expr(&mut self, expr: &ast::Expr) -> DualNarrowingConstraints<'db> { + let Some(target) = PlaceExpr::try_from_expr(expr) else { + return DualNarrowingConstraints::default(); }; + let place = self.expect_place(&target); - Some(NarrowingConstraints::from_iter([( + let positive = NarrowingConstraints::from_iter([( place, - NarrowingConstraint::intersection(ty), - )])) + NarrowingConstraint::intersection(Type::AlwaysFalsy.negate(self.db)), + )]); + let negative = NarrowingConstraints::from_iter([( + place, + NarrowingConstraint::intersection(Type::AlwaysTruthy.negate(self.db)), + )]); + + DualNarrowingConstraints::from_sides(Some(positive), Some(negative)) } fn evaluate_expr_named( &mut self, expr_named: &ast::ExprNamed, - is_positive: bool, - ) -> Option> { - self.evaluate_simple_expr(&expr_named.target, is_positive) + ) -> DualNarrowingConstraints<'db> { + self.evaluate_simple_expr(&expr_named.target) } fn evaluate_expr_eq(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option> { @@ -948,10 +1022,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { lhs_ty: Type<'db>, rhs_ty: Type<'db>, op: ast::CmpOp, - is_positive: bool, ) -> Option> { - let op = if is_positive { op } else { op.negate() }; - // `Divergent` shows up as an initial value in cycle recovery. If it appears on either side // of a potentially narrowing comparison, we don't want it to turn that comparison into a // no-op (e.g. because `Divergent` is not a singleton in the `IsNot` branch below), because @@ -996,6 +1067,18 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, expr_compare: &ast::ExprCompare, expression: Expression<'db>, + ) -> DualNarrowingConstraints<'db> { + let inference = infer_expression_types(self.db, expression, TypeContext::default()); + DualNarrowingConstraints::from_sides( + self.evaluate_expr_compare_for_polarity(expr_compare, inference, true), + self.evaluate_expr_compare_for_polarity(expr_compare, inference, false), + ) + } + + fn evaluate_expr_compare_for_polarity( + &mut self, + expr_compare: &ast::ExprCompare, + inference: &ExpressionInference<'db>, is_positive: bool, ) -> Option> { fn is_narrowing_target_candidate(expr: &ast::Expr) -> bool { @@ -1070,8 +1153,6 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { return None; } - let inference = infer_expression_types(self.db, expression, TypeContext::default()); - let comparator_tuples = std::iter::once(&**left) .chain(comparators) .tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>(); @@ -1313,7 +1394,11 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { // - `if x not in y` if narrowable_ast(left) && let Some(narrowable) = PlaceExpr::try_from_expr(left) - && let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive) + && let Some(ty) = self.evaluate_expr_compare_op( + lhs_ty, + rhs_ty, + if is_positive { *op } else { op.negate() }, + ) { let place = self.expect_place(&narrowable); let constraint = NarrowingConstraint::intersection(ty); @@ -1335,7 +1420,11 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { if !matches!(op, ast::CmpOp::In | ast::CmpOp::NotIn) && narrowable_ast(right) && let Some(narrowable) = PlaceExpr::try_from_expr(right) - && let Some(ty) = self.evaluate_expr_compare_op(rhs_ty, lhs_ty, *op, is_positive) + && let Some(ty) = self.evaluate_expr_compare_op( + rhs_ty, + lhs_ty, + if is_positive { *op } else { op.negate() }, + ) { let place = self.expect_place(&narrowable); let constraint = NarrowingConstraint::intersection(ty); @@ -1359,18 +1448,45 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, expr_call: &ast::ExprCall, expression: Expression<'db>, - is_positive: bool, - ) -> Option> { + ) -> DualNarrowingConstraints<'db> { let inference = infer_expression_types(self.db, expression, TypeContext::default()); - if let Some(type_guard_call_constraints) = - self.evaluate_type_guard_call(inference, expr_call, is_positive) + // If the return type of expr_call is TypeGuard (positive) / TypeIs: + if let Some(positive_constraints) = + self.evaluate_type_guard_call_for_polarity(inference, expr_call, true) { - return Some(type_guard_call_constraints); + let negative_constraints = + self.evaluate_type_guard_call_for_polarity(inference, expr_call, false); + return DualNarrowingConstraints::from_sides( + Some(positive_constraints), + negative_constraints, + ); } let callable_ty = inference.expression_type(&*expr_call.func); + if let Type::ClassLiteral(class_type) = callable_ty + && expr_call.arguments.args.len() == 1 + && expr_call.arguments.keywords.is_empty() + && class_type.is_known(self.db, KnownClass::Bool) + { + return self + .evaluate_expression_node_predicate(&expr_call.arguments.args[0], expression); + } + + DualNarrowingConstraints::from_sides( + self.evaluate_expr_call_for_polarity(expr_call, inference, callable_ty, true), + self.evaluate_expr_call_for_polarity(expr_call, inference, callable_ty, false), + ) + } + + fn evaluate_expr_call_for_polarity( + &mut self, + expr_call: &ast::ExprCall, + inference: &ExpressionInference<'db>, + callable_ty: Type<'db>, + is_positive: bool, + ) -> Option> { match callable_ty { // For the expression `len(E)`, we narrow the type based on whether len(E) is truthy // (i.e., whether E is non-empty). We only narrow the parts of the type where we know @@ -1442,25 +1558,13 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { )]) }) } - // for the expression `bool(E)`, we further narrow the type based on `E` - Type::ClassLiteral(class_type) - if expr_call.arguments.args.len() == 1 - && expr_call.arguments.keywords.is_empty() - && class_type.is_known(self.db, KnownClass::Bool) => - { - self.evaluate_expression_node_predicate( - &expr_call.arguments.args[0], - expression, - is_positive, - ) - } _ => None, } } // Helper to evaluate TypeGuard/TypeIs narrowing for a call expression. // This is based on the call expression's return type, so it applies to any callable type. - fn evaluate_type_guard_call( + fn evaluate_type_guard_call_for_polarity( &mut self, inference: &ExpressionInference<'db>, expr_call: &ast::ExprCall, @@ -1498,6 +1602,17 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, subject: Expression<'db>, singleton: ast::Singleton, + ) -> DualNarrowingConstraints<'db> { + DualNarrowingConstraints::from_sides( + self.evaluate_match_pattern_singleton_for_polarity(subject, singleton, true), + self.evaluate_match_pattern_singleton_for_polarity(subject, singleton, false), + ) + } + + fn evaluate_match_pattern_singleton_for_polarity( + &mut self, + subject: Expression<'db>, + singleton: ast::Singleton, is_positive: bool, ) -> Option> { let subject = PlaceExpr::try_from_expr(subject.node_ref(self.db, self.module))?; @@ -1520,6 +1635,18 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { subject: Expression<'db>, cls: Expression<'db>, kind: ClassPatternKind, + ) -> DualNarrowingConstraints<'db> { + DualNarrowingConstraints::from_sides( + self.evaluate_match_pattern_class_for_polarity(subject, cls, kind, true), + self.evaluate_match_pattern_class_for_polarity(subject, cls, kind, false), + ) + } + + fn evaluate_match_pattern_class_for_polarity( + &mut self, + subject: Expression<'db>, + cls: Expression<'db>, + kind: ClassPatternKind, is_positive: bool, ) -> Option> { if !kind.is_irrefutable() && !is_positive { @@ -1554,6 +1681,17 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, subject: Expression<'db>, value: Expression<'db>, + ) -> DualNarrowingConstraints<'db> { + DualNarrowingConstraints::from_sides( + self.evaluate_match_pattern_value_for_polarity(subject, value, true), + self.evaluate_match_pattern_value_for_polarity(subject, value, false), + ) + } + + fn evaluate_match_pattern_value_for_polarity( + &mut self, + subject: Expression<'db>, + value: Expression<'db>, is_positive: bool, ) -> Option> { let subject_node = subject.node_ref(self.db, self.module); @@ -1568,7 +1706,15 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module); let mut constraints = self - .evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive) + .evaluate_expr_compare_op( + subject_ty, + value_ty, + if is_positive { + ast::CmpOp::Eq + } else { + ast::CmpOp::NotEq + }, + ) .map(|ty| { NarrowingConstraints::from_iter([(place, NarrowingConstraint::intersection(ty))]) }) @@ -1616,39 +1762,45 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, subject: Expression<'db>, predicates: &Vec>, - is_positive: bool, - ) -> Option> { - let db = self.db; + ) -> DualNarrowingConstraints<'db> { + let mut positive: Option> = None; + let mut negative: Option> = None; + + for predicate in predicates { + let (sub_positive, sub_negative) = self + .evaluate_pattern_predicate_kind(predicate, subject) + .into_sides(); + + if let Some(sub_positive) = sub_positive { + if let Some(ref mut aggregated) = positive { + merge_constraints_or(aggregated, sub_positive, self.db); + } else { + positive = Some(sub_positive); + } + } - // DeMorgan's law---if the overall `or` is negated, we need to `and` the negated sub-constraints. - let merge_constraints = if is_positive { - merge_constraints_or - } else { - merge_constraints_and - }; + if let Some(sub_negative) = sub_negative { + if let Some(ref mut aggregated) = negative { + merge_constraints_and(aggregated, sub_negative, self.db); + } else { + negative = Some(sub_negative); + } + } + } - predicates - .iter() - .filter_map(|predicate| { - self.evaluate_pattern_predicate_kind(predicate, subject, is_positive) - }) - .reduce(|mut constraints, constraints_| { - merge_constraints(&mut constraints, constraints_, db); - constraints - }) + DualNarrowingConstraints::from_sides(positive, negative) } fn evaluate_bool_op( &mut self, expr_bool_op: &ExprBoolOp, expression: Expression<'db>, - is_positive: bool, - ) -> Option> { + ) -> DualNarrowingConstraints<'db> { let inference = infer_expression_types(self.db, expression, TypeContext::default()); let sub_constraints = expr_bool_op .values .iter() - // filter our arms with statically known truthiness + // Filter out arms with statically known truthiness. .filter(|expr| { inference.expression_type(*expr).bool(self.db) != match expr_bool_op.op { @@ -1656,40 +1808,27 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { BoolOp::Or => Truthiness::AlwaysFalse, } }) - .map(|sub_expr| { - self.evaluate_expression_node_predicate(sub_expr, expression, is_positive) - }) + .map(|sub_expr| self.evaluate_expression_node_predicate(sub_expr, expression)) .collect::>(); - match (expr_bool_op.op, is_positive) { - (BoolOp::And, true) | (BoolOp::Or, false) => { - let mut aggregation: Option = None; - for sub_constraint in sub_constraints.into_iter().flatten() { - if let Some(ref mut some_aggregation) = aggregation { - merge_constraints_and(some_aggregation, sub_constraint, self.db); - } else { - aggregation = Some(sub_constraint); - } - } - aggregation - } - (BoolOp::Or, true) | (BoolOp::And, false) => { - let (mut first, rest) = { - let mut it = sub_constraints.into_iter(); - (it.next()?, it) - }; - if let Some(ref mut first) = first { - for rest_constraint in rest { - if let Some(rest_constraint) = rest_constraint { - merge_constraints_or(first, rest_constraint, self.db); - } else { - return None; - } - } - } - first - } - } + let (positive_sub_constraints, negative_sub_constraints): (Vec<_>, Vec<_>) = + sub_constraints + .into_iter() + .map(DualNarrowingConstraints::into_sides) + .unzip(); + + let (positive, negative) = match expr_bool_op.op { + BoolOp::And => ( + self.merge_constraints_and_sequence(positive_sub_constraints), + self.merge_constraints_or_sequence(negative_sub_constraints), + ), + BoolOp::Or => ( + self.merge_constraints_or_sequence(positive_sub_constraints), + self.merge_constraints_and_sequence(negative_sub_constraints), + ), + }; + + DualNarrowingConstraints::from_sides(positive, negative) } /// Narrow tagged unions of `TypedDict`s with `Literal` keys.