diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index c1a9a404ba53a..83881611a39bb 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -74,6 +74,7 @@ use std::ops::Range; use itertools::Itertools; use rustc_hash::{FxHashMap, FxHashSet}; use salsa::plumbing::AsId; +use smallvec::SmallVec; use crate::types::generics::{GenericContext, InferableTypeVars, Specialization}; use crate::types::visitor::{ @@ -83,7 +84,7 @@ use crate::types::{ BoundTypeVarIdentity, BoundTypeVarInstance, IntersectionType, Type, TypeVarBoundOrConstraints, UnionType, walk_bound_type_var_type, }; -use crate::{Db, FxOrderMap, FxOrderSet}; +use crate::{Db, FxIndexMap, FxOrderSet}; /// An extension trait for building constraint sets from [`Option`] values. pub(crate) trait OptionConstraintsExtension { @@ -1004,9 +1005,8 @@ impl<'db> Node<'db> { Node::AlwaysTrue => {} Node::AlwaysFalse => {} Node::Interior(interior) => { - let map = interior.sequent_map(db); - let mut path = PathAssignments::default(); - self.for_each_path_inner(db, &mut f, map, &mut path); + let mut path = interior.path_assignments(db); + self.for_each_path_inner(db, &mut f, &mut path); } } } @@ -1015,7 +1015,6 @@ impl<'db> Node<'db> { self, db: &'db dyn Db, f: &mut dyn FnMut(&PathAssignments<'db>), - map: &SequentMap<'db>, path: &mut PathAssignments<'db>, ) { match self { @@ -1024,11 +1023,11 @@ impl<'db> Node<'db> { Node::Interior(interior) => { let constraint = interior.constraint(db); let source_order = interior.source_order(db); - path.walk_edge(db, map, constraint.when_true(), source_order, |path, _| { - interior.if_true(db).for_each_path_inner(db, f, map, path); + path.walk_edge(db, constraint.when_true(), source_order, |path, _| { + interior.if_true(db).for_each_path_inner(db, f, path); }); - path.walk_edge(db, map, constraint.when_false(), source_order, |path, _| { - 
interior.if_false(db).for_each_path_inner(db, f, map, path); + path.walk_edge(db, constraint.when_false(), source_order, |path, _| { + interior.if_false(db).for_each_path_inner(db, f, path); }); } } @@ -1040,19 +1039,13 @@ impl<'db> Node<'db> { Node::AlwaysTrue => true, Node::AlwaysFalse => false, Node::Interior(interior) => { - let map = interior.sequent_map(db); - let mut path = PathAssignments::default(); - self.is_always_satisfied_inner(db, map, &mut path) + let mut path = interior.path_assignments(db); + self.is_always_satisfied_inner(db, &mut path) } } } - fn is_always_satisfied_inner( - self, - db: &'db dyn Db, - map: &SequentMap<'db>, - path: &mut PathAssignments<'db>, - ) -> bool { + fn is_always_satisfied_inner(self, db: &'db dyn Db, path: &mut PathAssignments<'db>) -> bool { match self { Node::AlwaysTrue => true, Node::AlwaysFalse => false, @@ -1063,10 +1056,8 @@ impl<'db> Node<'db> { let constraint = interior.constraint(db); let source_order = interior.source_order(db); let true_always_satisfied = path - .walk_edge(db, map, constraint.when_true(), source_order, |path, _| { - interior - .if_true(db) - .is_always_satisfied_inner(db, map, path) + .walk_edge(db, constraint.when_true(), source_order, |path, _| { + interior.if_true(db).is_always_satisfied_inner(db, path) }) .unwrap_or(true); if !true_always_satisfied { @@ -1074,10 +1065,8 @@ impl<'db> Node<'db> { } // Ditto for the if_false branch - path.walk_edge(db, map, constraint.when_false(), source_order, |path, _| { - interior - .if_false(db) - .is_always_satisfied_inner(db, map, path) + path.walk_edge(db, constraint.when_false(), source_order, |path, _| { + interior.if_false(db).is_always_satisfied_inner(db, path) }) .unwrap_or(true) } @@ -1090,19 +1079,13 @@ impl<'db> Node<'db> { Node::AlwaysTrue => false, Node::AlwaysFalse => true, Node::Interior(interior) => { - let map = interior.sequent_map(db); - let mut path = PathAssignments::default(); - self.is_never_satisfied_inner(db, map, &mut path) + let 
mut path = interior.path_assignments(db); + self.is_never_satisfied_inner(db, &mut path) } } } - fn is_never_satisfied_inner( - self, - db: &'db dyn Db, - map: &SequentMap<'db>, - path: &mut PathAssignments<'db>, - ) -> bool { + fn is_never_satisfied_inner(self, db: &'db dyn Db, path: &mut PathAssignments<'db>) -> bool { match self { Node::AlwaysTrue => false, Node::AlwaysFalse => true, @@ -1113,8 +1096,8 @@ impl<'db> Node<'db> { let constraint = interior.constraint(db); let source_order = interior.source_order(db); let true_never_satisfied = path - .walk_edge(db, map, constraint.when_true(), source_order, |path, _| { - interior.if_true(db).is_never_satisfied_inner(db, map, path) + .walk_edge(db, constraint.when_true(), source_order, |path, _| { + interior.if_true(db).is_never_satisfied_inner(db, path) }) .unwrap_or(true); if !true_never_satisfied { @@ -1122,10 +1105,8 @@ impl<'db> Node<'db> { } // Ditto for the if_false branch - path.walk_edge(db, map, constraint.when_false(), source_order, |path, _| { - interior - .if_false(db) - .is_never_satisfied_inner(db, map, path) + path.walk_edge(db, constraint.when_false(), source_order, |path, _| { + interior.if_false(db).is_never_satisfied_inner(db, path) }) .unwrap_or(true) } @@ -1411,13 +1392,12 @@ impl<'db> Node<'db> { self, db: &'db dyn Db, should_remove: &mut dyn FnMut(ConstrainedTypeVar<'db>) -> bool, - map: &SequentMap<'db>, path: &mut PathAssignments<'db>, ) -> Self { match self { Node::AlwaysTrue => Node::AlwaysTrue, Node::AlwaysFalse => Node::AlwaysFalse, - Node::Interior(interior) => interior.abstract_one_inner(db, should_remove, map, path), + Node::Interior(interior) => interior.abstract_one_inner(db, should_remove, path), } } @@ -2007,8 +1987,7 @@ impl<'db> InteriorNode<'db> { #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] fn exists_one(self, db: &'db dyn Db, bound_typevar: BoundTypeVarIdentity<'db>) -> Node<'db> { - let map = self.sequent_map(db); - let mut path = PathAssignments::default(); + 
let mut path = self.path_assignments(db); let mentions_typevar = |ty: Type<'db>| match ty { Type::TypeVar(haystack) => haystack.identity(db) == bound_typevar, _ => false, @@ -2034,15 +2013,13 @@ impl<'db> InteriorNode<'db> { } false }, - map, &mut path, ) } #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] fn retain_one(self, db: &'db dyn Db, bound_typevar: BoundTypeVarIdentity<'db>) -> Node<'db> { - let map = self.sequent_map(db); - let mut path = PathAssignments::default(); + let mut path = self.path_assignments(db); self.abstract_one_inner( db, // Remove any node that constrains some other typevar than `bound_typevar`, and any @@ -2058,7 +2035,6 @@ impl<'db> InteriorNode<'db> { } false }, - map, &mut path, ) } @@ -2067,7 +2043,6 @@ impl<'db> InteriorNode<'db> { self, db: &'db dyn Db, should_remove: &mut dyn FnMut(ConstrainedTypeVar<'db>) -> bool, - map: &SequentMap<'db>, path: &mut PathAssignments<'db>, ) -> Node<'db> { let self_constraint = self.constraint(db); @@ -2089,13 +2064,10 @@ impl<'db> InteriorNode<'db> { let if_true = path .walk_edge( db, - map, self_constraint.when_true(), self_source_order, |path, new_range| { - let branch = - self.if_true(db) - .abstract_one_inner(db, should_remove, map, path); + let branch = self.if_true(db).abstract_one_inner(db, should_remove, path); path.assignments[new_range] .iter() .filter(|(assignment, _)| { @@ -2115,13 +2087,12 @@ impl<'db> InteriorNode<'db> { let if_false = path .walk_edge( db, - map, self_constraint.when_false(), self_source_order, |path, new_range| { - let branch = - self.if_false(db) - .abstract_one_inner(db, should_remove, map, path); + let branch = self + .if_false(db) + .abstract_one_inner(db, should_remove, path); path.assignments[new_range] .iter() .filter(|(assignment, _)| { @@ -2144,24 +2115,19 @@ impl<'db> InteriorNode<'db> { let if_true = path .walk_edge( db, - map, self_constraint.when_true(), self_source_order, - |path, _| { - self.if_true(db) - .abstract_one_inner(db, should_remove, 
map, path) - }, + |path, _| self.if_true(db).abstract_one_inner(db, should_remove, path), ) .unwrap_or(Node::AlwaysFalse); let if_false = path .walk_edge( db, - map, self_constraint.when_false(), self_source_order, |path, _| { self.if_false(db) - .abstract_one_inner(db, should_remove, map, path) + .abstract_one_inner(db, should_remove, path) }, ) .unwrap_or(Node::AlwaysFalse); @@ -2209,35 +2175,18 @@ impl<'db> InteriorNode<'db> { } } - /// Returns a sequent map for this BDD, which records the relationships between the constraints - /// that appear in the BDD. - #[salsa::tracked( - returns(ref), - cycle_initial=sequent_map_cycle_initial, - heap_size=ruff_memory_usage::heap_size, - )] - fn sequent_map(self, db: &'db dyn Db) -> SequentMap<'db> { - tracing::trace!( - target: "ty_python_semantic::types::constraints::SequentMap", - constraints = %Node::Interior(self).display(db), - "create sequent map", - ); - + fn path_assignments(self, db: &'db dyn Db) -> PathAssignments<'db> { // Sort the constraints in this BDD by their `source_order`s before adding them to the // sequent map. This ensures that constraints appear in the sequent map in a stable order. // The constraints mentioned in a BDD should all have distinct `source_order`s, so an // unstable sort is fine. - let mut constraints = Vec::new(); + let mut constraints: SmallVec<[_; 8]> = SmallVec::new(); Node::Interior(self).for_each_constraint(db, &mut |constraint, source_order| { constraints.push((constraint, source_order)); }); constraints.sort_unstable_by_key(|(_, source_order)| *source_order); - let mut map = SequentMap::default(); - for (constraint, _) in constraints { - map.add(db, constraint); - } - map + PathAssignments::new(constraints.into_iter().map(|(constraint, _)| constraint)) } /// Returns a simplified version of a BDD. 
@@ -2646,14 +2595,6 @@ impl<'db> InteriorNode<'db> { } } -fn sequent_map_cycle_initial<'db>( - _db: &'db dyn Db, - _id: salsa::Id, - _self: InteriorNode<'db>, -) -> SequentMap<'db> { - SequentMap::default() -} - /// An assignment of one BDD variable to either `true` or `false`. (When evaluating a BDD, we /// must provide an assignment for each variable present in the BDD.) #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, salsa::Update)] @@ -2783,7 +2724,19 @@ impl<'db> ConstraintAssignment<'db> { /// /// - `C → D`: This indicates that `C` on its own is enough to imply `D`. Any path that assumes `C` /// holds but `D` does _not_ is impossible and can be pruned. -#[derive(Debug, Default, Eq, PartialEq, get_size2::GetSize, salsa::Update)] +/// +/// Sequent maps are primarily used when walking a BDD path with a [`PathAssignments`]. The +/// `PathAssignments` will hold a sequent map containing all of the constraints that are +/// encountered during the walk. It builds up its sequent map lazily, so that it only has to +/// include sequents for the constraints that are actually encountered. However, we also don't want +/// to perform duplicate work if we perform multiple BDD walks on the same constraint set. The +/// [`for_constraint`][Self::for_constraint] and [`for_constraint_pair`][Self::for_constraint_pair] +/// methods are salsa-tracked, to ensure that we only perform them once for any particular +/// constraint or pair of constraints. `PathAssignments` invokes these methods when it encounters a +/// new constraint, and then merges those cached sequents into its own sequent map. (That means we +/// also share the work of calculating the sequent map across `PathAssignments` for _different_ +/// constraint sets.) 
+#[derive(Clone, Debug, Default, Eq, PartialEq, get_size2::GetSize, salsa::Update)] struct SequentMap<'db> { /// Sequents of the form `¬C₁ → false` single_tautologies: FxHashSet>, @@ -2796,54 +2749,79 @@ struct SequentMap<'db> { >, /// Sequents of the form `C → D` single_implications: FxHashMap, FxOrderSet>>, - /// Constraints that we have already processed - processed: FxHashSet>, - /// Constraints that enqueued to be processed - enqueued: Vec>, } impl<'db> SequentMap<'db> { - fn add(&mut self, db: &'db dyn Db, constraint: ConstrainedTypeVar<'db>) { - self.enqueue_constraint(constraint); - - while let Some(constraint) = self.enqueued.pop() { - // If we've already processed this constraint, we can skip it. - if !self.processed.insert(constraint) { - continue; - } - - // First see if we can create any sequents from the constraint on its own. + /// Returns a sequent map containing the sequents that we can infer from a single constraint in + /// isolation. This method is salsa-tracked so that we only perform this work once per + /// constraint. + fn for_constraint(db: &'db dyn Db, constraint: ConstrainedTypeVar<'db>) -> &'db Self { + #[salsa::tracked(returns(ref), heap_size=ruff_memory_usage::heap_size)] + fn for_constraint_inner<'db>( + db: &'db dyn Db, + constraint: ConstrainedTypeVar<'db>, + ) -> SequentMap<'db> { tracing::trace!( target: "ty_python_semantic::types::constraints::SequentMap", constraint = %constraint.display(db), "add sequents for constraint", ); - self.add_sequents_for_single(db, constraint); + let mut map = SequentMap::default(); + map.add_sequents_for_single(db, constraint); + map + } - // Then check this constraint against all of the other ones we've seen so far, seeing - // if they're related to each other. 
- let processed = std::mem::take(&mut self.processed); - for other in &processed { - if constraint != *other { - tracing::trace!( - target: "ty_python_semantic::types::constraints::SequentMap", - left = %constraint.display(db), - right = %other.display(db), - "add sequents for constraint pair", - ); - self.add_sequents_for_pair(db, constraint, *other); - } - } - self.processed = processed; + for_constraint_inner(db, constraint) + } + + /// Returns a sequent map containing the sequents that we can infer from a pair of constraints. + /// This method is salsa-tracked so that we only perform this work once per constraint pair. + /// + /// (Note that this method is _not_ commutative; you should provide `left` and `right` in the + /// order that they appear in the source code, so that we can construct derived constraints + /// that retain that ordering.) + fn for_constraint_pair( + db: &'db dyn Db, + left: ConstrainedTypeVar<'db>, + right: ConstrainedTypeVar<'db>, + ) -> &'db Self { + #[salsa::tracked(returns(ref), heap_size=ruff_memory_usage::heap_size)] + fn for_constraint_pair_inner<'db>( + db: &'db dyn Db, + left: ConstrainedTypeVar<'db>, + right: ConstrainedTypeVar<'db>, + ) -> SequentMap<'db> { + tracing::trace!( + target: "ty_python_semantic::types::constraints::SequentMap", + left = %left.display(db), + right = %right.display(db), + "add sequents for constraint pair", + ); + let mut map = SequentMap::default(); + map.add_sequents_for_pair(db, left, right); + map } + + for_constraint_pair_inner(db, left, right) } - fn enqueue_constraint(&mut self, constraint: ConstrainedTypeVar<'db>) { - // If we've already processed this constraint, we can skip it. - if self.processed.contains(&constraint) { - return; + /// Merges the sequents from another sequent map into this one. 
+ fn merge(&mut self, db: &'db dyn Db, other: &Self) { + self.single_tautologies.extend(&other.single_tautologies); + self.pair_impossibilities + .extend(&other.pair_impossibilities); + for ((ante1, ante2), post) in &other.pair_implications { + self.pair_implications + .entry(Self::pair_key(db, *ante1, *ante2)) + .or_default() + .extend(post); + } + for (ante, post) in &other.single_implications { + self.single_implications + .entry(*ante) + .or_default() + .extend(post); } - self.enqueued.push(constraint); } fn pair_key( @@ -2987,7 +2965,6 @@ impl<'db> SequentMap<'db> { }; self.add_single_implication(db, constraint, post_constraint); - self.enqueue_constraint(post_constraint); } fn add_sequents_for_pair( @@ -3127,7 +3104,6 @@ impl<'db> SequentMap<'db> { let post_constraint = ConstrainedTypeVar::new(db, constrained_typevar, new_lower, new_upper); self.add_pair_implication(db, left_constraint, right_constraint, post_constraint); - self.enqueue_constraint(post_constraint); } fn add_mutual_sequents_for_same_typevars( @@ -3177,7 +3153,6 @@ impl<'db> SequentMap<'db> { _ => return, }; self.add_pair_implication(db, left_constraint, right_constraint, post_constraint); - self.enqueue_constraint(post_constraint); }; try_one_direction(left_constraint, right_constraint); @@ -3231,7 +3206,6 @@ impl<'db> SequentMap<'db> { ); self.add_single_implication(db, intersection_constraint, left_constraint); self.add_single_implication(db, intersection_constraint, right_constraint); - self.enqueue_constraint(intersection_constraint); } // The sequent map only needs to include constraints that might appear in a BDD. If the @@ -3318,12 +3292,29 @@ impl<'db> SequentMap<'db> { /// The collection of constraints that we know to be true or false at a certain point when /// traversing a BDD. 
-#[derive(Debug, Default)] +#[derive(Debug)] pub(crate) struct PathAssignments<'db> { - assignments: FxOrderMap, usize>, + map: SequentMap<'db>, + assignments: FxIndexMap, usize>, + /// Constraints that we have discovered, mapped to whether we have processed them yet. (This + /// ensures a stable order for all of the derived constraints that we create, while still + /// letting us create them lazily.) + discovered: FxIndexMap, bool>, } impl<'db> PathAssignments<'db> { + fn new(constraints: impl IntoIterator>) -> Self { + let discovered = constraints + .into_iter() + .map(|constraint| (constraint, false)) + .collect(); + Self { + map: SequentMap::default(), + assignments: FxIndexMap::default(), + discovered, + } + } + /// Walks one of the outgoing edges of an internal BDD node. `assignment` describes the /// constraint that the BDD node checks, and whether we are following the `if_true` or /// `if_false` edge. @@ -3349,7 +3340,6 @@ impl<'db> PathAssignments<'db> { fn walk_edge( &mut self, db: &'db dyn Db, - map: &SequentMap<'db>, assignment: ConstraintAssignment<'db>, source_order: usize, f: impl FnOnce(&mut Self, Range) -> R, @@ -3369,7 +3359,7 @@ impl<'db> PathAssignments<'db> { edge = %assignment.display(db), "walk edge", ); - let found_conflict = self.add_assignment(db, map, assignment, source_order); + let found_conflict = self.add_assignment(db, assignment, source_order); let result = if found_conflict.is_err() { // If that results in the path now being impossible due to a contradiction, return // without invoking the callback. @@ -3413,13 +3403,34 @@ impl<'db> PathAssignments<'db> { self.assignments.contains_key(&assignment) } + /// Update our sequent map to ensure that it holds all of the sequents that involve the given + /// constraint. We do not calculate the new sequents directly. 
Instead, we call + /// [`SequentMap::for_constraint`] and [`for_constraint_pair`][SequentMap::for_constraint_pair] + /// to calculate _and cache_ the sequents, so that if we walk another constraint set + /// containing this constraint, we reuse the work to calculate its sequents. + fn discover_constraint(&mut self, db: &'db dyn Db, constraint: ConstrainedTypeVar<'db>) { + // If we've already processed this constraint, we can skip it. + let existing = self.discovered.insert(constraint, true); + let already_processed = existing.is_some_and(|existing| existing); + if already_processed { + return; + } + + let single_map = SequentMap::for_constraint(db, constraint); + self.map.merge(db, single_map); + + for existing in self.discovered.keys().dropping_back(1) { + let pair_map = SequentMap::for_constraint_pair(db, *existing, constraint); + self.map.merge(db, pair_map); + } + } + /// Adds a new assignment, along with any derived information that we can infer from the new /// assignment combined with the assignments we've already seen. If any of this causes the path /// to become invalid, due to a contradiction, returns a [`PathAssignmentConflict`] error. fn add_assignment( &mut self, db: &'db dyn Db, - map: &SequentMap<'db>, assignment: ConstraintAssignment<'db>, source_order: usize, ) -> Result<(), PathAssignmentConflict> { @@ -3453,7 +3464,9 @@ impl<'db> PathAssignments<'db> { // don't anticipate the sequent maps to be very large. We might consider avoiding the // brute-force search. - for ante in &map.single_tautologies { + self.discover_constraint(db, assignment.constraint()); + + for ante in &self.map.single_tautologies { if self.assignment_holds(ante.when_false()) { // The sequent map says (ante1) is always true, and the current path asserts that // it's false. 
@@ -3470,7 +3483,7 @@ impl<'db> PathAssignments<'db> { } } - for (ante1, ante2) in &map.pair_impossibilities { + for (ante1, ante2) in &self.map.pair_impossibilities { if self.assignment_holds(ante1.when_true()) && self.assignment_holds(ante2.when_true()) { // The sequent map says (ante1 ∧ ante2) is an impossible combination, and the @@ -3489,24 +3502,29 @@ impl<'db> PathAssignments<'db> { } } - for ((ante1, ante2), posts) in &map.pair_implications { + let mut new_constraints = Vec::new(); + for ((ante1, ante2), posts) in &self.map.pair_implications { for post in posts { if self.assignment_holds(ante1.when_true()) && self.assignment_holds(ante2.when_true()) { - self.add_assignment(db, map, post.when_true(), source_order)?; + new_constraints.push(*post); } } } - for (ante, posts) in &map.single_implications { + for (ante, posts) in &self.map.single_implications { for post in posts { if self.assignment_holds(ante.when_true()) { - self.add_assignment(db, map, post.when_true(), source_order)?; + new_constraints.push(*post); } } } + for new_constraint in new_constraints { + self.add_assignment(db, new_constraint.when_true(), source_order)?; + } + Ok(()) } }