diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index eb09335da40533..4181b74e6bea32 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -23,9 +23,6 @@ //! Note that all lower and upper bounds in a constraint must be fully static. We take the bottom //! and top materializations of the types to remove any gradual forms if needed. //! -//! Lower and upper bounds must also be normalized. This lets us identify, for instance, -//! two constraints with equivalent but differently ordered unions as their bounds. -//! //! NOTE: This module is currently in a transitional state. We've added the BDD [`ConstraintSet`] //! representation, and updated all of our property checks to build up a constraint set and then //! check whether it is ever or always satisfiable, as appropriate. We are not yet inferring @@ -435,9 +432,6 @@ impl<'db> ConstrainedTypeVar<'db> { return Node::AlwaysTrue; } - let lower = lower.normalized(db); - let upper = upper.normalized(db); - // We have an (arbitrary) ordering for typevars. If the upper and/or lower bounds are // typevars, we have to ensure that the bounds are "later" according to that order than the // typevar being constrained. @@ -510,6 +504,15 @@ impl<'db> ConstrainedTypeVar<'db> { ConstraintAssignment::Negative(self) } + fn normalized(self, db: &'db dyn Db) -> Self { + Self::new( + db, + self.typevar(db), + self.lower(db).normalized(db), + self.upper(db).normalized(db), + ) + } + /// Defines the ordering of the variables in a constraint set BDD. /// /// If we only care about _correctness_, we can choose any ordering that we want, as long as @@ -542,9 +545,8 @@ impl<'db> ConstrainedTypeVar<'db> { /// Returns the intersection of two range constraints, or `None` if the intersection is empty. fn intersect(self, db: &'db dyn Db, other: Self) -> Option { // (s₁ ≤ α ≤ t₁) ∧ (s₂ ≤ α ≤ t₂) = (s₁ ∪ s₂) ≤ α ≤ (t₁ ∩ t₂)) - let lower = UnionType::from_elements(db, [self.lower(db), other.lower(db)]).normalized(db); - let upper = - IntersectionType::from_elements(db, [self.upper(db), other.upper(db)]).normalized(db); + let lower = UnionType::from_elements(db, [self.lower(db), other.lower(db)]); + let upper = IntersectionType::from_elements(db, [self.upper(db), other.upper(db)]); // If `lower ≰ upper`, then the intersection is empty, since there is no type that is both // greater than `lower`, and less than `upper`. @@ -1217,7 +1219,7 @@ impl<'db> Node<'db> { Node::AlwaysFalse => {} Node::AlwaysTrue => self.clauses.push(self.current_clause.clone()), Node::Interior(interior) => { - let interior_constraint = interior.constraint(db); + let interior_constraint = interior.constraint(db).normalized(db); self.current_clause.push(interior_constraint.when_true()); self.visit_node(db, interior.if_true(db)); self.current_clause.pop(); @@ -1751,6 +1753,8 @@ impl<'db> InteriorNode<'db> { // non-empty. match left_constraint.intersect(db, right_constraint) { Some(intersection_constraint) => { + let intersection_constraint = intersection_constraint.normalized(db); + // If the intersection is non-empty, we need to create a new constraint to // represent that intersection. We also need to add the new constraint to our // seen set and (if we haven't already seen it) to the to-visit queue. @@ -2013,33 +2017,49 @@ struct SequentMap<'db> { /// Sequents of the form `C₁ ∧ C₂ → false` impossibilities: FxHashSet<(ConstrainedTypeVar<'db>, ConstrainedTypeVar<'db>)>, /// Sequents of the form `C₁ ∧ C₂ → D` - pair_implications: - FxHashMap<(ConstrainedTypeVar<'db>, ConstrainedTypeVar<'db>), Vec>>, + pair_implications: FxHashMap< + (ConstrainedTypeVar<'db>, ConstrainedTypeVar<'db>), + FxHashSet>, + >, /// Sequents of the form `C → D` - single_implications: FxHashMap, Vec>>, + single_implications: FxHashMap, FxHashSet>>, /// Constraints that we have already processed processed: FxHashSet>, + /// Constraints that enqueued to be processed + enqueued: Vec>, } impl<'db> SequentMap<'db> { fn add(&mut self, db: &'db dyn Db, constraint: ConstrainedTypeVar<'db>) { - // If we've already seen this constraint, we can skip it. - if !self.processed.insert(constraint) { - return; - } + self.enqueue_constraint(constraint); - // Otherwise, check this constraint against all of the other ones we've seen so far, seeing - // if they're related to each other. - let processed = std::mem::take(&mut self.processed); - for other in &processed { - if constraint != *other { - self.add_sequents_for_pair(db, constraint, *other); + while let Some(constraint) = self.enqueued.pop() { + // If we've already processed this constraint, we can skip it. + if !self.processed.insert(constraint) { + continue; } + + // Otherwise, check this constraint against all of the other ones we've seen so far, seeing + // if they're related to each other. + let processed = std::mem::take(&mut self.processed); + for other in &processed { + if constraint != *other { + self.add_sequents_for_pair(db, constraint, *other); + } + } + self.processed = processed; + + // And see if we can create any sequents from the constraint on its own. + self.add_sequents_for_single(db, constraint); } - self.processed = processed; + } - // And see if we can create any sequents from the constraint on its own. - self.add_sequents_for_single(db, constraint); + fn enqueue_constraint(&mut self, constraint: ConstrainedTypeVar<'db>) { + // If we've already processed this constraint, we can skip it. + if self.processed.contains(&constraint) { + return; + } + self.enqueued.push(constraint); } fn pair_key( @@ -2071,13 +2091,14 @@ impl<'db> SequentMap<'db> { ante2: ConstrainedTypeVar<'db>, post: ConstrainedTypeVar<'db>, ) { - if ante1 == post || ante2 == post { + // If either antecedent implies the consequent on its own, this new sequent is redundant. + if ante1.implies(db, post) || ante2.implies(db, post) { return; } self.pair_implications .entry(Self::pair_key(db, ante1, ante2)) .or_default() - .push(post); + .insert(post); } fn add_single_implication( @@ -2088,7 +2109,10 @@ impl<'db> SequentMap<'db> { if ante == post { return; } - self.single_implications.entry(ante).or_default().push(post); + self.single_implications + .entry(ante) + .or_default() + .insert(post); } fn add_sequents_for_single(&mut self, db: &'db dyn Db, constraint: ConstrainedTypeVar<'db>) { @@ -2104,32 +2128,31 @@ impl<'db> SequentMap<'db> { let lower = constraint.lower(db); let upper = constraint.upper(db); - match (lower, upper) { + let post_constraint = match (lower, upper) { // Case 1 (Type::TypeVar(lower_typevar), Type::TypeVar(upper_typevar)) => { if !lower_typevar.is_same_typevar_as(db, upper_typevar) { - let post_constraint = - ConstrainedTypeVar::new(db, lower_typevar, Type::Never, upper); - self.add_single_implication(constraint, post_constraint); + ConstrainedTypeVar::new(db, lower_typevar, Type::Never, upper) + } else { + return; } } // Case 2 (Type::TypeVar(lower_typevar), _) => { - let post_constraint = - ConstrainedTypeVar::new(db, lower_typevar, Type::Never, upper); - self.add_single_implication(constraint, post_constraint); + ConstrainedTypeVar::new(db, lower_typevar, Type::Never, upper) } // Case 3 (_, Type::TypeVar(upper_typevar)) => { - let post_constraint = - ConstrainedTypeVar::new(db, upper_typevar, lower, Type::object()); - self.add_single_implication(constraint, post_constraint); + ConstrainedTypeVar::new(db, upper_typevar, lower, Type::object()) } - _ => {} - } + _ => return, + }; + + self.add_single_implication(constraint, post_constraint); + self.enqueue_constraint(post_constraint); } fn add_sequents_for_pair( @@ -2240,6 +2263,7 @@ impl<'db> SequentMap<'db> { let post_constraint = ConstrainedTypeVar::new(db, constrained_typevar, new_lower, new_upper); self.add_pair_implication(db, left_constraint, right_constraint, post_constraint); + self.enqueue_constraint(post_constraint); } fn add_mutual_sequents_for_same_typevars( @@ -2270,6 +2294,7 @@ impl<'db> SequentMap<'db> { _ => return, }; self.add_pair_implication(db, left_constraint, right_constraint, post_constraint); + self.enqueue_constraint(post_constraint); }; try_one_direction(left_constraint, right_constraint); @@ -2282,6 +2307,18 @@ impl<'db> SequentMap<'db> { left_constraint: ConstrainedTypeVar<'db>, right_constraint: ConstrainedTypeVar<'db>, ) { + // These might seem redundant with the intersection check below, since `a → b` means that + // `a ∧ b = a`. But we are not normalizing constraint bounds, and these clauses help us + // identify constraints that are identical besides e.g. ordering of union/intersection + // elements. (For instance, when processing `T ≤ τ₁ & τ₂` and `T ≤ τ₂ & τ₁`, these clauses + // would add sequents for `(T ≤ τ₁ & τ₂) → (T ≤ τ₂ & τ₁)` and vice versa.) + if left_constraint.implies(db, right_constraint) { + self.add_single_implication(left_constraint, right_constraint); + } + if right_constraint.implies(db, left_constraint) { + self.add_single_implication(right_constraint, left_constraint); + } + match left_constraint.intersect(db, right_constraint) { Some(intersection_constraint) => { self.add_pair_implication( @@ -2292,6 +2329,7 @@ impl<'db> SequentMap<'db> { ); self.add_single_implication(intersection_constraint, left_constraint); self.add_single_implication(intersection_constraint, right_constraint); + self.enqueue_constraint(intersection_constraint); } None => { self.add_impossibility(db, left_constraint, right_constraint);