diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/match.md b/crates/ty_python_semantic/resources/mdtest/narrow/match.md index 8d95a1d5404f2..2e1ad06abd8e3 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/match.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/match.md @@ -252,7 +252,7 @@ def _(x: Literal["foo", b"bar"] | int): pass case b"bar" if reveal_type(x): # revealed: Literal[b"bar"] | int pass - case _ if reveal_type(x): # revealed: Literal["foo", b"bar"] | int + case _ if reveal_type(x): # revealed: int | Literal["foo", b"bar"] pass ``` @@ -350,6 +350,45 @@ except ValueError: pass ``` +## Narrowing is preserved when a terminal branch prevents a path from flowing through + +When one branch of a `match` statement is terminal (e.g. contains `raise`), narrowing from the +non-terminal branches is preserved after the merge point. + +```py +class A: ... +class B: ... +class C: ... + +def _(x: A | B | C): + match x: + case A(): + pass + case B(): + pass + case _: + raise ValueError() + + reveal_type(x) # revealed: B | (A & ~B) +``` + +Reassignment in non-terminal branches is also preserved when the default branch is terminal: + +```py +def _(number_of_periods: int | None, interval: str): + match interval: + case "monthly": + if number_of_periods is None: + number_of_periods = 1 + case "daily": + if number_of_periods is None: + number_of_periods = 30 + case _: + raise ValueError("unsupported interval") + + reveal_type(number_of_periods) # revealed: int +``` + ## Narrowing tagged unions of tuples Narrow unions of tuples based on literal tag elements in `match` statements: diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md b/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md index fbbe5794d8355..76d96d746baf1 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md @@ -49,3 +49,196 @@ def _(x: int | None): reveal_type(x) # revealed: int ``` + +## Narrowing is preserved when a terminal branch prevents a path from flowing through + +When one branch of an if/elif/else is terminal (e.g. contains `return`), narrowing from the +non-terminal branches is preserved after the merge point. + +```py +class A: ... +class B: ... +class C: ... + +def _(x: A | B | C): + if isinstance(x, A): + pass + elif isinstance(x, B): + pass + else: + return + + # Only the if-branch (A) and elif-branch (B) flow through. + # The else-branch returned, so its narrowing doesn't participate. + reveal_type(x) # revealed: B | (A & ~B) +``` + +## Narrowing is preserved with multiple terminal branches + +```py +class A: ... +class B: ... +class C: ... +class D: ... + +def _(x: A | B | C | D): + if isinstance(x, A): + return + elif isinstance(x, B): + pass + elif isinstance(x, C): + pass + else: + return + + # Only the elif-B and elif-C branches flow through. + reveal_type(x) # revealed: (C & ~A) | (B & ~A & ~C) +``` + +## Multiple sequential if-statements don't leak narrowing + +After a complete if/else where both branches flow through (no terminal), narrowing should be +cancelled out at the merge point. + +```py +class A: ... +class B: ... +class C: ... + +def _(x: A | B | C): + if isinstance(x, A): + pass + else: + pass + + # Narrowing cancels out: both paths flow, so type is unchanged. + reveal_type(x) # revealed: A | B | C + + if isinstance(x, B): + pass + else: + pass + + # Second if-statement's narrowing also cancels out. 
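+ # (Each complete if/else records `atom(P)` on one path and `NOT(atom(P))` on the other; their union simplifies to always-true in the TDD, so no narrowing survives the merge.)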
+ reveal_type(x) # revealed: A | B | C +``` + +## Narrowing after a `NoReturn` call in one branch + +When a branch calls a function that returns `NoReturn`/`Never`, we know that branch terminates and +doesn't contribute to the type after the if statement. + +```py +import sys + +def _(val: int | None): + if val is None: + sys.exit() + reveal_type(val) # revealed: int +``` + +This also works when the `NoReturn` function is called in the else branch: + +```py +import sys + +def _(val: int | None): + if val is not None: + pass + else: + sys.exit() + reveal_type(val) # revealed: int +``` + +And for elif branches: + +```py +import sys + +def _(val: int | str | None): + if val is None: + sys.exit() + elif isinstance(val, int): + pass + else: + sys.exit() + reveal_type(val) # revealed: int +``` + +## Narrowing through always-true branches + +When a terminal (`return`) is inside an always-true branch, narrowing propagates through because the +else-branch is unreachable and contributes `Never` to the union. + +```py +def _(x: int | None): + if True: + if x is None: + return + reveal_type(x) # revealed: int + reveal_type(x) # revealed: int +``` + +```py +def _(x: int | None): + if 1 + 1 == 2: + if x is None: + return + reveal_type(x) # revealed: int + + # TODO: should be `int` (the else-branch of `1 + 1 == 2` is unreachable) + reveal_type(x) # revealed: int | None +``` + +This also works when the always-true condition is nested inside a narrowing branch: + +```py +def _(x: int | None): + if x is None: + if 1 + 1 == 2: + return + + # TODO: should be `int` (the inner always-true branch makes the outer + # if-branch terminal) + reveal_type(x) # revealed: int | None +``` + +## Narrowing from `assert` should not affect reassigned variables + +When a variable is reassigned after an `assert`, the narrowing from the assert should not apply to +the new value. + +```py +def foo(arg: int) -> int | None: + return None + +def bar() -> None: + v = foo(1) + assert v is None + + v = foo(2) + # v was reassigned, so the assert narrowing shouldn't apply + reveal_type(v) # revealed: int | None +``` + +## Narrowing from `NoReturn` should not affect reassigned variables + +When a variable is narrowed due to a `NoReturn` call in one branch and then reassigned, the +narrowing should only apply before the reassignment, not after. 
+ +```py +import sys + +def foo() -> int | None: + return 3 + +def bar(): + v = foo() + if v is None: + sys.exit() + reveal_type(v) # revealed: int + + v = foo() + # v was reassigned, so any narrowing shouldn't apply + reveal_type(v) # revealed: int | None +``` diff --git a/crates/ty_python_semantic/resources/mdtest/terminal_statements.md b/crates/ty_python_semantic/resources/mdtest/terminal_statements.md index 25d458ae67b31..05bf71894ec61 100644 --- a/crates/ty_python_semantic/resources/mdtest/terminal_statements.md +++ b/crates/ty_python_semantic/resources/mdtest/terminal_statements.md @@ -618,9 +618,7 @@ def g(x: int | None): if x is None: sys.exit(1) - # TODO: should be just `int`, not `int | None` - # See https://github.com/astral-sh/ty/issues/685 - reveal_type(x) # revealed: int | None + reveal_type(x) # revealed: int ``` ### Possibly unresolved diagnostics diff --git a/crates/ty_python_semantic/src/lib.rs b/crates/ty_python_semantic/src/lib.rs index 9adaff6906c4c..b693d96164bf7 100644 --- a/crates/ty_python_semantic/src/lib.rs +++ b/crates/ty_python_semantic/src/lib.rs @@ -36,7 +36,6 @@ pub mod ast_node_ref; mod db; mod dunder_all; pub mod lint; -pub(crate) mod list; mod node_key; pub(crate) mod place; mod program; diff --git a/crates/ty_python_semantic/src/list.rs b/crates/ty_python_semantic/src/list.rs deleted file mode 100644 index 547c69e47e1a8..0000000000000 --- a/crates/ty_python_semantic/src/list.rs +++ /dev/null @@ -1,745 +0,0 @@ -//! Sorted, arena-allocated association lists -//! -//! An [_association list_][alist], which is a linked list of key/value pairs. We additionally -//! guarantee that the elements of an association list are sorted (by their keys), and that they do -//! not contain any entries with duplicate keys. -//! -//! Association lists have fallen out of favor in recent decades, since you often need operations -//! that are inefficient on them. In particular, looking up a random element by index is O(n), just -//! like a linked list; and looking up an element by key is also O(n), since you must do a linear -//! scan of the list to find the matching element. The typical implementation also suffers from -//! poor cache locality and high memory allocation overhead, since individual list cells are -//! typically allocated separately from the heap. We solve that last problem by storing the cells -//! of an association list in an [`IndexVec`] arena. -//! -//! We exploit structural sharing where possible, reusing cells across multiple lists when we can. -//! That said, we don't guarantee that lists are canonical — it's entirely possible for two lists -//! with identical contents to use different list cells and have different identifiers. -//! -//! Given all of this, association lists have the following benefits: -//! -//! - Lists can be represented by a single 32-bit integer (the index into the arena of the head of -//! the list). -//! - Lists can be cloned in constant time, since the underlying cells are immutable. -//! - Lists can be combined quickly (for both intersection and union), especially when you already -//! have to zip through both input lists to combine each key's values in some way. -//! -//! There is one remaining caveat: -//! -//! - You should construct lists in key order; doing this lets you insert each value in constant time. -//! Inserting entries in reverse order results in _quadratic_ overall time to construct the list. -//! -//! Lists are created using a [`ListBuilder`], and once created are accessed via a [`ListStorage`]. -//! -//! 
## Tests -//! -//! This module contains quickcheck-based property tests. -//! -//! These tests are disabled by default, as they are non-deterministic and slow. You can run them -//! explicitly using: -//! -//! ```sh -//! cargo test -p ruff_index -- --ignored list::property_tests -//! ``` -//! -//! The number of tests (default: 100) can be controlled by setting the `QUICKCHECK_TESTS` -//! environment variable. For example: -//! -//! ```sh -//! QUICKCHECK_TESTS=10000 cargo test … -//! ``` -//! -//! If you want to run these tests for a longer period of time, it's advisable to run them in -//! release mode. As some tests are slower than others, it's advisable to run them in a loop until -//! they fail: -//! -//! ```sh -//! export QUICKCHECK_TESTS=100000 -//! while cargo test --release -p ruff_index -- \ -//! --ignored list::property_tests; do :; done -//! ``` -//! -//! [alist]: https://en.wikipedia.org/wiki/Association_list - -use std::cmp::Ordering; -use std::marker::PhantomData; -use std::ops::Deref; - -use ruff_index::{IndexVec, newtype_index}; - -/// A handle to an association list. Use [`ListStorage`] to access its elements, and -/// [`ListBuilder`] to construct other lists based on this one. -#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd, get_size2::GetSize)] -pub(crate) struct List { - last: Option, - _phantom: PhantomData<(K, V)>, -} - -impl List { - pub(crate) const fn empty() -> List { - List::new(None) - } - - const fn new(last: Option) -> List { - List { - last, - _phantom: PhantomData, - } - } -} - -impl Default for List { - fn default() -> Self { - List::empty() - } -} - -#[newtype_index] -#[derive(PartialOrd, Ord, get_size2::GetSize)] -struct ListCellId; - -/// Stores one or more association lists. This type provides read-only access to the lists. Use a -/// [`ListBuilder`] to create lists. -#[derive(Debug, Eq, PartialEq, get_size2::GetSize)] -pub(crate) struct ListStorage { - cells: IndexVec>, -} - -/// Each association list is represented by a sequence of snoc cells. A snoc cell is like the more -/// familiar cons cell `(a : (b : (c : nil)))`, but in reverse `(((nil : a) : b) : c)`. -/// -/// **Terminology**: The elements of a cons cell are usually called `head` and `tail` (assuming -/// you're not in Lisp-land, where they're called `car` and `cdr`). The elements of a snoc cell -/// are usually called `rest` and `last`. -#[derive(Debug, Eq, PartialEq, get_size2::GetSize)] -struct ListCell { - rest: Option, - key: K, - value: V, -} - -/// Constructs one or more association lists. -#[derive(Debug, Eq, PartialEq)] -pub(crate) struct ListBuilder { - storage: ListStorage, - - /// Scratch space that lets us implement our list operations iteratively instead of - /// recursively. - /// - /// The snoc-list representation that we use for alists is very common in functional - /// programming, and the simplest implementations of most of the operations are defined - /// recursively on that data structure. However, they are not _tail_ recursive, which means - /// that the call stack grows linearly with the size of the input, which can be a problem for - /// large lists. - /// - /// You can often rework those recursive implementations into iterative ones using an - /// _accumulator_, but that comes at the cost of reversing the list. If we didn't care about - /// ordering, that wouldn't be a problem. Since we want our lists to be sorted, we can't rely - /// on that on its own. 
- /// - /// The next standard trick is to use an accumulator, and use a fix-up step at the end to - /// reverse the (reversed) result in the accumulator, restoring the correct order. - /// - /// So, that's what we do! However, as one last optimization, we don't build up alist cells in - /// our accumulator, since that would add wasteful cruft to our list storage. Instead, we use a - /// normal Vec as our accumulator, holding the key/value pairs that should be stitched onto the - /// end of whatever result list we are creating. For our fix-up step, we can consume a Vec in - /// reverse order by `pop`ping the elements off one by one. - scratch: Vec<(K, V)>, -} - -impl Default for ListBuilder { - fn default() -> Self { - ListBuilder { - storage: ListStorage { - cells: IndexVec::default(), - }, - scratch: Vec::default(), - } - } -} - -impl Deref for ListBuilder { - type Target = ListStorage; - fn deref(&self) -> &ListStorage { - &self.storage - } -} - -impl ListBuilder { - /// Finalizes a `ListBuilder`. After calling this, you cannot create any new lists managed by - /// this storage. - pub(crate) fn build(mut self) -> ListStorage { - self.storage.cells.shrink_to_fit(); - self.storage - } - - /// Adds a new cell to the list. - /// - /// Adding an element always returns a non-empty list, which means we could technically use `I` - /// as our return type, since we never return `None`. However, for consistency with our other - /// methods, we always use `Option` as the return type for any method that can return a - /// list. - #[expect(clippy::unnecessary_wraps)] - fn add_cell(&mut self, rest: Option, key: K, value: V) -> Option { - Some(self.storage.cells.push(ListCell { rest, key, value })) - } - - /// Returns an entry pointing at where `key` would be inserted into a list. - /// - /// Note that when we add a new element to a list, we might have to clone the keys and values - /// of some existing elements. This is because list cells are immutable once created, since - /// they might be shared across multiple lists. We must therefore create new cells for every - /// element that appears after the new element. - /// - /// That means that you should construct lists in key order, since that means that there are no - /// entries to duplicate for each insertion. If you construct the list in reverse order, we - /// will have to duplicate O(n) entries for each insertion, making it _quadratic_ to construct - /// the entire list. - pub(crate) fn entry(&mut self, list: List, key: K) -> ListEntry<'_, K, V> - where - K: Clone + Ord, - V: Clone, - { - self.scratch.clear(); - - // Iterate through the input list, looking for the position where the key should be - // inserted. We will need to create new list cells for any elements that appear after the - // new key. Stash those away in our scratch accumulator as we step through the input. The - // result of the loop is that "rest" of the result list, which we will stitch the new key - // (and any succeeding keys) onto. - let mut curr = list.last; - while let Some(curr_id) = curr { - let cell = &self.storage.cells[curr_id]; - match key.cmp(&cell.key) { - // We found an existing entry in the input list with the desired key. - Ordering::Equal => { - return ListEntry { - builder: self, - list, - key, - rest: ListTail::Occupied(curr_id), - }; - } - // The input list does not already contain this key, and this is where we should - // add it. 
- Ordering::Greater => { - return ListEntry { - builder: self, - list, - key, - rest: ListTail::Vacant(curr_id), - }; - } - // If this key is in the list, it's further along. We'll need to create a new cell - // for this entry in the result list, so add its contents to the scratch - // accumulator. - Ordering::Less => { - let new_key = cell.key.clone(); - let new_value = cell.value.clone(); - self.scratch.push((new_key, new_value)); - curr = cell.rest; - } - } - } - - // We made it all the way through the list without finding the desired key, so it belongs - // at the beginning. (And we will unfortunately have to duplicate every existing cell if - // the caller proceeds with inserting the new key!) - ListEntry { - builder: self, - list, - key, - rest: ListTail::Beginning, - } - } -} - -/// A view into a list, indicating where a key would be inserted. -pub(crate) struct ListEntry<'a, K, V = ()> { - builder: &'a mut ListBuilder, - list: List, - key: K, - /// Points at the element that already contains `key`, if there is one, or the element - /// immediately before where it would go, if not. - rest: ListTail, -} - -enum ListTail { - /// The list does not already contain `key`, and it would go at the beginning of the list. - Beginning, - /// The list already contains `key` - Occupied(I), - /// The list does not already contain key, and it would go immediately after the given element - Vacant(I), -} - -impl ListEntry<'_, K, V> -where - K: Clone, - V: Clone, -{ - fn stitch_up(self, rest: Option, value: V) -> List { - let mut last = rest; - last = self.builder.add_cell(last, self.key, value); - while let Some((key, value)) = self.builder.scratch.pop() { - last = self.builder.add_cell(last, key, value); - } - List::new(last) - } - - /// Inserts a new key/value into the list if the key is not already present. If the list - /// already contains `key`, we return the original list as-is, and do not invoke your closure. - pub(crate) fn or_insert_with(self, f: F) -> List - where - F: FnOnce() -> V, - { - let rest = match self.rest { - // If the list already contains `key`, we don't need to replace anything, and can - // return the original list unmodified. - ListTail::Occupied(_) => return self.list, - // Otherwise we have to create a new entry and stitch it onto the list. - ListTail::Beginning => None, - ListTail::Vacant(index) => Some(index), - }; - self.stitch_up(rest, f()) - } - - /// Inserts a new key and the default value into the list if the key is not already present. If - /// the list already contains `key`, we return the original list as-is. - pub(crate) fn or_insert_default(self) -> List - where - V: Default, - { - self.or_insert_with(V::default) - } -} - -impl ListBuilder { - /// Returns the intersection of two lists. The result will contain an entry for any key that - /// appears in both lists. The corresponding values will be combined using the `combine` - /// function that you provide. - #[expect(clippy::needless_pass_by_value)] - pub(crate) fn intersect_with( - &mut self, - a: List, - b: List, - mut combine: F, - ) -> List - where - K: Clone + Ord, - V: Clone, - F: FnMut(&V, &V) -> V, - { - self.scratch.clear(); - - // Zip through the lists, building up the keys/values of the new entries into our scratch - // vector. Continue until we run out of elements in either list. (Any remaining elements in - // the other list cannot possibly be in the intersection.) 
- let mut a = a.last; - let mut b = b.last; - while let (Some(a_id), Some(b_id)) = (a, b) { - let a_cell = &self.storage.cells[a_id]; - let b_cell = &self.storage.cells[b_id]; - match a_cell.key.cmp(&b_cell.key) { - // Both lists contain this key; combine their values - Ordering::Equal => { - let new_key = a_cell.key.clone(); - let new_value = combine(&a_cell.value, &b_cell.value); - self.scratch.push((new_key, new_value)); - a = a_cell.rest; - b = b_cell.rest; - } - // a's key is only present in a, so it's not included in the result. - Ordering::Greater => a = a_cell.rest, - // b's key is only present in b, so it's not included in the result. - Ordering::Less => b = b_cell.rest, - } - } - - // Once the iteration loop terminates, we stitch the new entries back together into proper - // alist cells. - let mut last = None; - while let Some((key, value)) = self.scratch.pop() { - last = self.add_cell(last, key, value); - } - List::new(last) - } -} - -// ---- -// Sets - -impl ListStorage { - /// Iterates through the elements in a set _in reverse order_. - #[expect(clippy::needless_pass_by_value)] - pub(crate) fn iter_set_reverse(&self, set: List) -> ListSetReverseIterator<'_, K> { - ListSetReverseIterator { - storage: self, - curr: set.last, - } - } -} - -pub(crate) struct ListSetReverseIterator<'a, K> { - storage: &'a ListStorage, - curr: Option, -} - -impl<'a, K> Iterator for ListSetReverseIterator<'a, K> { - type Item = &'a K; - - fn next(&mut self) -> Option { - let cell = &self.storage.cells[self.curr?]; - self.curr = cell.rest; - Some(&cell.key) - } -} - -impl ListBuilder { - /// Adds an element to a set. - pub(crate) fn insert(&mut self, set: List, element: K) -> List - where - K: Clone + Ord, - { - self.entry(set, element).or_insert_default() - } - - /// Returns the intersection of two sets. The result will contain any value that appears in - /// both sets. 
- pub(crate) fn intersect(&mut self, a: List, b: List) -> List - where - K: Clone + Ord, - { - self.intersect_with(a, b, |(), ()| ()) - } -} - -// ----- -// Tests - -#[cfg(test)] -mod tests { - use super::*; - - use std::fmt::Display; - use std::fmt::Write; - - // ---- - // Sets - - impl ListStorage - where - K: Display, - { - fn display_set(&self, list: List) -> String { - let elements: Vec<_> = self.iter_set_reverse(list).collect(); - let mut result = String::new(); - result.push('['); - for element in elements.into_iter().rev() { - if result.len() > 1 { - result.push_str(", "); - } - write!(&mut result, "{element}").unwrap(); - } - result.push(']'); - result - } - } - - #[test] - fn can_insert_into_set() { - let mut builder = ListBuilder::::default(); - - // Build up the set in order - let empty = List::empty(); - let set1 = builder.insert(empty, 1); - let set12 = builder.insert(set1, 2); - let set123 = builder.insert(set12, 3); - let set1232 = builder.insert(set123, 2); - assert_eq!(builder.display_set(empty), "[]"); - assert_eq!(builder.display_set(set1), "[1]"); - assert_eq!(builder.display_set(set12), "[1, 2]"); - assert_eq!(builder.display_set(set123), "[1, 2, 3]"); - assert_eq!(builder.display_set(set1232), "[1, 2, 3]"); - - // And in reverse order - let set3 = builder.insert(empty, 3); - let set32 = builder.insert(set3, 2); - let set321 = builder.insert(set32, 1); - let set3212 = builder.insert(set321, 2); - assert_eq!(builder.display_set(empty), "[]"); - assert_eq!(builder.display_set(set3), "[3]"); - assert_eq!(builder.display_set(set32), "[2, 3]"); - assert_eq!(builder.display_set(set321), "[1, 2, 3]"); - assert_eq!(builder.display_set(set3212), "[1, 2, 3]"); - } - - #[test] - fn can_intersect_sets() { - let mut builder = ListBuilder::::default(); - - let empty = List::empty(); - let set1 = builder.insert(empty, 1); - let set12 = builder.insert(set1, 2); - let set123 = builder.insert(set12, 3); - let set1234 = builder.insert(set123, 4); - - let set2 = builder.insert(empty, 2); - let set24 = builder.insert(set2, 4); - let set245 = builder.insert(set24, 5); - let set2457 = builder.insert(set245, 7); - - let intersection = builder.intersect(empty, empty); - assert_eq!(builder.display_set(intersection), "[]"); - let intersection = builder.intersect(empty, set1234); - assert_eq!(builder.display_set(intersection), "[]"); - let intersection = builder.intersect(empty, set2457); - assert_eq!(builder.display_set(intersection), "[]"); - let intersection = builder.intersect(set1, set1234); - assert_eq!(builder.display_set(intersection), "[1]"); - let intersection = builder.intersect(set1, set2457); - assert_eq!(builder.display_set(intersection), "[]"); - let intersection = builder.intersect(set2, set1234); - assert_eq!(builder.display_set(intersection), "[2]"); - let intersection = builder.intersect(set2, set2457); - assert_eq!(builder.display_set(intersection), "[2]"); - let intersection = builder.intersect(set1234, set2457); - assert_eq!(builder.display_set(intersection), "[2, 4]"); - } - - // ---- - // Maps - - impl ListStorage { - /// Iterates through the entries in a list _in reverse order by key_. 
- #[expect(clippy::needless_pass_by_value)] - pub(crate) fn iter_reverse(&self, list: List) -> ListReverseIterator<'_, K, V> { - ListReverseIterator { - storage: self, - curr: list.last, - } - } - } - - pub(crate) struct ListReverseIterator<'a, K, V> { - storage: &'a ListStorage, - curr: Option, - } - - impl<'a, K, V> Iterator for ListReverseIterator<'a, K, V> { - type Item = (&'a K, &'a V); - - fn next(&mut self) -> Option { - let cell = &self.storage.cells[self.curr?]; - self.curr = cell.rest; - Some((&cell.key, &cell.value)) - } - } - - impl ListStorage - where - K: Display, - V: Display, - { - fn display(&self, list: List) -> String { - let entries: Vec<_> = self.iter_reverse(list).collect(); - let mut result = String::new(); - result.push('['); - for (key, value) in entries.into_iter().rev() { - if result.len() > 1 { - result.push_str(", "); - } - write!(&mut result, "{key}:{value}").unwrap(); - } - result.push(']'); - result - } - } - - #[test] - fn can_insert_into_map() { - let mut builder = ListBuilder::::default(); - - // Build up the map in order - let empty = List::empty(); - let map1 = builder.entry(empty, 1).or_insert_with(|| 1); - let map12 = builder.entry(map1, 2).or_insert_with(|| 2); - let map123 = builder.entry(map12, 3).or_insert_with(|| 3); - let map1232 = builder.entry(map123, 2).or_insert_with(|| 4); - assert_eq!(builder.display(empty), "[]"); - assert_eq!(builder.display(map1), "[1:1]"); - assert_eq!(builder.display(map12), "[1:1, 2:2]"); - assert_eq!(builder.display(map123), "[1:1, 2:2, 3:3]"); - assert_eq!(builder.display(map1232), "[1:1, 2:2, 3:3]"); - - // And in reverse order - let map3 = builder.entry(empty, 3).or_insert_with(|| 3); - let map32 = builder.entry(map3, 2).or_insert_with(|| 2); - let map321 = builder.entry(map32, 1).or_insert_with(|| 1); - let map3212 = builder.entry(map321, 2).or_insert_with(|| 4); - assert_eq!(builder.display(empty), "[]"); - assert_eq!(builder.display(map3), "[3:3]"); - assert_eq!(builder.display(map32), "[2:2, 3:3]"); - assert_eq!(builder.display(map321), "[1:1, 2:2, 3:3]"); - assert_eq!(builder.display(map3212), "[1:1, 2:2, 3:3]"); - } - - #[test] - fn can_intersect_maps() { - let mut builder = ListBuilder::::default(); - - let empty = List::empty(); - let map1 = builder.entry(empty, 1).or_insert_with(|| 1); - let map12 = builder.entry(map1, 2).or_insert_with(|| 2); - let map123 = builder.entry(map12, 3).or_insert_with(|| 3); - let map1234 = builder.entry(map123, 4).or_insert_with(|| 4); - - let map2 = builder.entry(empty, 2).or_insert_with(|| 20); - let map24 = builder.entry(map2, 4).or_insert_with(|| 40); - let map245 = builder.entry(map24, 5).or_insert_with(|| 50); - let map2457 = builder.entry(map245, 7).or_insert_with(|| 70); - - let intersection = builder.intersect_with(empty, empty, |a, b| a + b); - assert_eq!(builder.display(intersection), "[]"); - let intersection = builder.intersect_with(empty, map1234, |a, b| a + b); - assert_eq!(builder.display(intersection), "[]"); - let intersection = builder.intersect_with(empty, map2457, |a, b| a + b); - assert_eq!(builder.display(intersection), "[]"); - let intersection = builder.intersect_with(map1, map1234, |a, b| a + b); - assert_eq!(builder.display(intersection), "[1:2]"); - let intersection = builder.intersect_with(map1, map2457, |a, b| a + b); - assert_eq!(builder.display(intersection), "[]"); - let intersection = builder.intersect_with(map2, map1234, |a, b| a + b); - assert_eq!(builder.display(intersection), "[2:22]"); - let intersection = builder.intersect_with(map2, 
map2457, |a, b| a + b); - assert_eq!(builder.display(intersection), "[2:40]"); - let intersection = builder.intersect_with(map1234, map2457, |a, b| a + b); - assert_eq!(builder.display(intersection), "[2:22, 4:44]"); - } -} - -// -------------- -// Property tests - -#[cfg(test)] -mod property_tests { - use super::*; - - use std::collections::{BTreeMap, BTreeSet}; - - impl ListBuilder - where - K: Clone + Ord, - { - fn set_from_elements<'a>(&mut self, elements: impl IntoIterator) -> List - where - K: 'a, - { - let mut set = List::empty(); - for element in elements { - set = self.insert(set, element.clone()); - } - set - } - } - - // For most of the tests below, we use a vec as our input, instead of a HashSet or BTreeSet, - // since we want to test the behavior of adding duplicate elements to the set. - - #[quickcheck_macros::quickcheck] - #[ignore] - #[expect(clippy::needless_pass_by_value)] - fn roundtrip_set_from_vec(elements: Vec) -> bool { - let mut builder = ListBuilder::default(); - let set = builder.set_from_elements(&elements); - let expected: BTreeSet<_> = elements.iter().copied().collect(); - let actual = builder.iter_set_reverse(set).copied(); - actual.eq(expected.into_iter().rev()) - } - - #[quickcheck_macros::quickcheck] - #[ignore] - #[expect(clippy::needless_pass_by_value)] - fn roundtrip_set_intersection(a_elements: Vec, b_elements: Vec) -> bool { - let mut builder = ListBuilder::default(); - let a = builder.set_from_elements(&a_elements); - let b = builder.set_from_elements(&b_elements); - let intersection = builder.intersect(a, b); - let a_set: BTreeSet<_> = a_elements.iter().copied().collect(); - let b_set: BTreeSet<_> = b_elements.iter().copied().collect(); - let expected: Vec<_> = a_set.intersection(&b_set).copied().collect(); - let actual = builder.iter_set_reverse(intersection).copied(); - actual.eq(expected.into_iter().rev()) - } - - impl ListBuilder - where - K: Clone + Ord, - V: Clone + Eq, - { - fn set_from_pairs<'a, I>(&mut self, pairs: I) -> List - where - K: 'a, - V: 'a, - I: IntoIterator, - I::IntoIter: DoubleEndedIterator, - { - let mut list = List::empty(); - for (key, value) in pairs.into_iter().rev() { - list = self - .entry(list, key.clone()) - .or_insert_with(|| value.clone()); - } - list - } - } - - fn join(a: &BTreeMap, b: &BTreeMap) -> BTreeMap, Option)> - where - K: Clone + Ord, - V: Clone + Ord, - { - let mut joined: BTreeMap, Option)> = BTreeMap::new(); - for (k, v) in a { - joined.entry(k.clone()).or_default().0 = Some(v.clone()); - } - for (k, v) in b { - joined.entry(k.clone()).or_default().1 = Some(v.clone()); - } - joined - } - - #[quickcheck_macros::quickcheck] - #[ignore] - #[expect(clippy::needless_pass_by_value)] - fn roundtrip_list_from_vec(pairs: Vec<(u16, u16)>) -> bool { - let mut builder = ListBuilder::default(); - let list = builder.set_from_pairs(&pairs); - let expected: BTreeMap<_, _> = pairs.iter().copied().collect(); - let actual = builder.iter_reverse(list).map(|(k, v)| (*k, *v)); - actual.eq(expected.into_iter().rev()) - } - - #[quickcheck_macros::quickcheck] - #[ignore] - #[expect(clippy::needless_pass_by_value)] - fn roundtrip_list_intersection( - a_elements: Vec<(u16, u16)>, - b_elements: Vec<(u16, u16)>, - ) -> bool { - let mut builder = ListBuilder::default(); - let a = builder.set_from_pairs(&a_elements); - let b = builder.set_from_pairs(&b_elements); - let intersection = builder.intersect_with(a, b, |a, b| a + b); - let a_map: BTreeMap<_, _> = a_elements.iter().copied().collect(); - let b_map: BTreeMap<_, _> = 
b_elements.iter().copied().collect(); - let intersection_map = join(&a_map, &b_map); - let expected: Vec<_> = intersection_map - .into_iter() - .filter_map(|(k, (v1, v2))| Some((k, v1? + v2?))) - .collect(); - let actual = builder.iter_reverse(intersection).map(|(k, v)| (*k, *v)); - actual.eq(expected.into_iter().rev()) - } -} diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index 6a9923af35f02..436c2963f2efc 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -815,10 +815,10 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { fn record_expression_narrowing_constraint( &mut self, predicate_node: &ast::Expr, - ) -> PredicateOrLiteral<'db> { + ) -> (PredicateOrLiteral<'db>, ScopedPredicateId) { let predicate = self.build_predicate(predicate_node); - self.record_narrowing_constraint(predicate); - predicate + let predicate_id = self.record_narrowing_constraint(predicate); + (predicate, predicate_id) } fn build_predicate(&mut self, predicate_node: &ast::Expr) -> PredicateOrLiteral<'db> { @@ -881,11 +881,18 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { } /// Adds and records a narrowing constraint for only the places that could possibly be narrowed. - fn record_narrowing_constraint(&mut self, predicate: PredicateOrLiteral<'db>) { + /// + /// Returns the `ScopedPredicateId` for the positive predicate, which can later be passed to + /// `record_negated_narrowing_constraint` for TDD-level negation. + fn record_narrowing_constraint( + &mut self, + predicate: PredicateOrLiteral<'db>, + ) -> ScopedPredicateId { let possibly_narrowed = self.compute_possibly_narrowed_places(&predicate); let use_def = self.current_use_def_map_mut(); let predicate_id = use_def.add_predicate(predicate); use_def.record_narrowing_constraint_for_places(predicate_id, &possibly_narrowed); + predicate_id } /// Computes the conservative set of places that could possibly be narrowed by a predicate. @@ -926,14 +933,18 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { /// Negates the given predicate and then adds it as a narrowing constraint to the places /// that could possibly be narrowed. + /// + /// Takes the `ScopedPredicateId` from the positive recording so that TDD-level negation + /// (`add_not_constraint`) uses the same atom. This ensures `atom(P) OR NOT(atom(P))` + /// simplifies to `ALWAYS_TRUE`, correctly cancelling narrowing after complete if/else blocks. fn record_negated_narrowing_constraint( &mut self, predicate: PredicateOrLiteral<'db>, - ) -> ScopedPredicateId { + predicate_id: ScopedPredicateId, + ) { let possibly_narrowed = self.compute_possibly_narrowed_places(&predicate); - let id = self.add_negated_predicate(predicate); - self.record_narrowing_constraint_id_for_places(id, &possibly_narrowed); - id + self.current_use_def_map_mut() + .record_negated_narrowing_constraint_for_places(predicate_id, &possibly_narrowed); } /// Records that all remaining statements in the current block are unreachable. @@ -1058,7 +1069,12 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { pattern: &ast::Pattern, guard: Option<&ast::Expr>, previous_pattern: Option>, - ) -> (PredicateOrLiteral<'db>, PatternPredicate<'db>) { + is_catchall: bool, + ) -> ( + PredicateOrLiteral<'db>, + ScopedPredicateId, + PatternPredicate<'db>, + ) { // This is called for the top-level pattern of each match arm. 
We need to create a // standalone expression for each arm of a match statement, since they can introduce // constraints on the match subject. (Or more accurately, for the match arm's pattern, @@ -1082,12 +1098,24 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { guard, previous_pattern.map(Box::new), ); + let predicate = PredicateOrLiteral::Predicate(Predicate { node: PredicateNode::Pattern(pattern_predicate), is_positive: true, }); - self.record_narrowing_constraint(predicate); - (predicate, pattern_predicate) + + // For the last catchall case (irrefutable wildcard without guard), we skip + // recording the narrowing constraint from the pattern. The accumulated negated + // constraints from earlier cases (~P1, ~P2, ...) are sufficient. This ensures + // `P1 OR (~P1 AND P2) OR (~P1 AND ~P2)` simplifies to ALWAYS_TRUE, preserving + // the original type after an exhaustive match. The reachability and pattern + // predicates are still created normally for proper control flow tracking. + let predicate_id = if is_catchall { + ScopedPredicateId::ALWAYS_TRUE + } else { + self.record_narrowing_constraint(predicate) + }; + (predicate, predicate_id, pattern_predicate) } /// Record an expression that needs to be a Salsa ingredient, because we need to infer its type @@ -1234,7 +1262,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { for if_expr in &generator.ifs { self.visit_expr(if_expr); - self.record_expression_narrowing_constraint(if_expr); + let _ = self.record_expression_narrowing_constraint(if_expr); } for generator in generators_iter { @@ -1252,7 +1280,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { for if_expr in &generator.ifs { self.visit_expr(if_expr); - self.record_expression_narrowing_constraint(if_expr); + let _ = self.record_expression_narrowing_constraint(if_expr); } } @@ -1940,7 +1968,8 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { ast::Stmt::If(node) => { self.visit_expr(&node.test); let mut no_branch_taken = self.flow_snapshot(); - let mut last_predicate = self.record_expression_narrowing_constraint(&node.test); + let (mut last_predicate, mut last_narrowing_id) = + self.record_expression_narrowing_constraint(&node.test); let mut last_reachability_constraint = self.record_reachability_constraint(last_predicate); @@ -1981,7 +2010,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { // taken self.flow_restore(no_branch_taken.clone()); - self.record_negated_narrowing_constraint(last_predicate); + self.record_negated_narrowing_constraint(last_predicate, last_narrowing_id); self.record_negated_reachability_constraint(last_reachability_constraint); if let Some(elif_test) = clause_test { @@ -1989,7 +2018,8 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { // A test expression is evaluated whether the branch is taken or not no_branch_taken = self.flow_snapshot(); - last_predicate = self.record_expression_narrowing_constraint(elif_test); + (last_predicate, last_narrowing_id) = + self.record_expression_narrowing_constraint(elif_test); last_reachability_constraint = self.record_reachability_constraint(last_predicate); @@ -2034,7 +2064,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { self.visit_expr(test); let pre_loop = self.flow_snapshot(); - let predicate = self.record_expression_narrowing_constraint(test); + let (predicate, predicate_id) = self.record_expression_narrowing_constraint(test); self.record_reachability_constraint(predicate); let outer_loop = self.push_loop(); @@ -2058,7 +2088,7 @@ impl<'ast> 
Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { .add_atom(later_predicate_id); self.record_negated_reachability_constraint(later_reachability_constraint); - self.record_negated_narrowing_constraint(predicate); + self.record_negated_narrowing_constraint(predicate, predicate_id); self.visit_body(orelse); @@ -2165,12 +2195,14 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { // here because the effects of visiting a pattern is binding // symbols, and this doesn't occur unless the pattern // actually matches - let (match_predicate, match_pattern_predicate) = self + let is_catchall = has_catchall && i == cases.len() - 1; + let (match_predicate, match_narrowing_id, match_pattern_predicate) = self .add_pattern_narrowing_constraint( subject_expr, &case.pattern, case.guard.as_deref(), previous_pattern, + is_catchall, ); previous_pattern = Some(match_pattern_predicate); let reachability_constraint = @@ -2188,10 +2220,22 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { node: PredicateNode::Expression(guard_expr), is_positive: true, }); - self.record_negated_narrowing_constraint(predicate); + // Add the predicate once, then use TDD-level negation for the failure + // path. This ensures the positive and negative atoms share the same ID. + let guard_predicate_id = self.add_predicate(predicate); + let possibly_narrowed = self.compute_possibly_narrowed_places(&predicate); + self.current_use_def_map_mut() + .record_negated_narrowing_constraint_for_places( + guard_predicate_id, + &possibly_narrowed, + ); let match_success_guard_failure = self.flow_snapshot(); self.flow_restore(post_guard_eval); - self.record_narrowing_constraint(predicate); + self.current_use_def_map_mut() + .record_narrowing_constraint_for_places( + guard_predicate_id, + &possibly_narrowed, + ); match_success_guard_failure }); @@ -2204,7 +2248,10 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { // one. The last one will just become the state that we merge the other // snapshots into. self.flow_restore(no_case_matched.clone()); - self.record_negated_narrowing_constraint(match_predicate); + self.record_negated_narrowing_constraint( + match_predicate, + match_narrowing_id, + ); self.record_negated_reachability_constraint(reachability_constraint); if let Some(match_success_guard_failure) = match_success_guard_failure { self.flow_merge(match_success_guard_failure); @@ -2508,9 +2555,17 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { }), is_positive: false, }; - self.record_reachability_constraint(PredicateOrLiteral::Predicate( - predicate, - )); + let constraint = self.record_reachability_constraint( + PredicateOrLiteral::Predicate(predicate), + ); + + // Also gate narrowing by this constraint: if the call returns + // `Never`, any narrowing in the current branch should be + // invalidated (since this path is unreachable). This enables + // narrowing to be preserved after if-statements where one branch + // calls a `NoReturn` function like `sys.exit()`. 
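+ // This is recorded for *all* places (not just the tested one), since an unreachable branch invalidates every piece of narrowing established within it.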
+ self.current_use_def_map_mut() + .record_narrowing_constraint_for_all_places(constraint); } } } @@ -2729,13 +2784,13 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { }) => { self.visit_expr(test); let pre_if = self.flow_snapshot(); - let predicate = self.record_expression_narrowing_constraint(test); + let (predicate, predicate_id) = self.record_expression_narrowing_constraint(test); let reachability_constraint = self.record_reachability_constraint(predicate); self.visit_expr(body); let post_body = self.flow_snapshot(); self.flow_restore(pre_if); - self.record_negated_narrowing_constraint(predicate); + self.record_negated_narrowing_constraint(predicate, predicate_id); self.record_negated_reachability_constraint(reachability_constraint); self.visit_expr(orelse); self.flow_merge(post_body); diff --git a/crates/ty_python_semantic/src/semantic_index/narrowing_constraints.rs b/crates/ty_python_semantic/src/semantic_index/narrowing_constraints.rs index 8d27dd1e201af..8a5aa2ee61a5d 100644 --- a/crates/ty_python_semantic/src/semantic_index/narrowing_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/narrowing_constraints.rs @@ -5,40 +5,22 @@ //! associated with a different narrowing constraint at different points in a file. See the //! [`use_def`][crate::semantic_index::use_def] module for more details. //! -//! This module defines how narrowing constraints are stored internally. -//! -//! A _narrowing constraint_ consists of a list of _predicates_, each of which corresponds with an -//! expression in the source file (represented by a [`Predicate`]). We need to support the -//! following operations on narrowing constraints: -//! -//! - Adding a new predicate to an existing constraint -//! - Merging two constraints together, which produces the _intersection_ of their predicates -//! - Iterating through the predicates in a constraint -//! -//! In particular, note that we do not need random access to the predicates in a constraint. That -//! means that we can use a simple [_sorted association list_][crate::list] as our data structure. -//! That lets us use a single 32-bit integer to store each narrowing constraint, no matter how many -//! predicates it contains. It also makes merging two narrowing constraints fast, since alists -//! support fast intersection. -//! -//! Because we visit the contents of each scope in source-file order, and assign scoped IDs in -//! source-file order, that means that we will tend to visit narrowing constraints in order by -//! their predicate IDs. This is exactly how to get the best performance from our alist -//! implementation. +//! Narrowing constraints are represented as TDD (ternary decision diagram) nodes, sharing the +//! same graph as reachability constraints. This allows narrowing constraints to support AND, OR, +//! and NOT operations, which is essential for correctly preserving narrowing information across +//! control flow merges (e.g. after if/elif/else with terminal branches). //! //! [`Predicate`]: crate::semantic_index::predicate::Predicate -use crate::list::{List, ListBuilder, ListSetReverseIterator, ListStorage}; use crate::semantic_index::ast_ids::ScopedUseId; -use crate::semantic_index::predicate::ScopedPredicateId; +use crate::semantic_index::reachability_constraints::ScopedReachabilityConstraintId; use crate::semantic_index::scope::FileScopeId; /// A narrowing constraint associated with a live binding. /// -/// A constraint is a list of [`Predicate`]s that each constrain the type of the binding's place. 
-/// -/// [`Predicate`]: crate::semantic_index::predicate::Predicate -pub(crate) type ScopedNarrowingConstraint = List; +/// This is a TDD node ID in the shared reachability constraints graph. +/// `ALWAYS_TRUE` means "no narrowing constraint" (the base type is unchanged). +pub(crate) type ScopedNarrowingConstraint = ScopedReachabilityConstraintId; #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum ConstraintKey { @@ -46,108 +28,3 @@ pub(crate) enum ConstraintKey { NestedScope(FileScopeId), UseId(ScopedUseId), } - -/// One of the [`Predicate`]s in a narrowing constraint, which constraints the type of the -/// binding's place. -/// -/// Note that those [`Predicate`]s are stored in [their own per-scope -/// arena][crate::semantic_index::predicate::Predicates], so internally we use a -/// [`ScopedPredicateId`] to refer to the underlying predicate. -/// -/// [`Predicate`]: crate::semantic_index::predicate::Predicate -#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd, get_size2::GetSize)] -pub(crate) struct ScopedNarrowingConstraintPredicate(ScopedPredicateId); - -impl ScopedNarrowingConstraintPredicate { - /// Returns (the ID of) the `Predicate` - pub(crate) fn predicate(self) -> ScopedPredicateId { - self.0 - } -} - -impl From for ScopedNarrowingConstraintPredicate { - fn from(predicate: ScopedPredicateId) -> ScopedNarrowingConstraintPredicate { - ScopedNarrowingConstraintPredicate(predicate) - } -} - -/// A collection of narrowing constraints for a given scope. -#[derive(Debug, Eq, PartialEq, get_size2::GetSize)] -pub(crate) struct NarrowingConstraints { - lists: ListStorage, -} - -// Building constraints -// -------------------- - -/// A builder for creating narrowing constraints. -#[derive(Debug, Default, Eq, PartialEq)] -pub(crate) struct NarrowingConstraintsBuilder { - lists: ListBuilder, -} - -impl NarrowingConstraintsBuilder { - pub(crate) fn build(self) -> NarrowingConstraints { - NarrowingConstraints { - lists: self.lists.build(), - } - } - - /// Adds a predicate to an existing narrowing constraint. - pub(crate) fn add_predicate_to_constraint( - &mut self, - constraint: ScopedNarrowingConstraint, - predicate: ScopedNarrowingConstraintPredicate, - ) -> ScopedNarrowingConstraint { - self.lists.insert(constraint, predicate) - } - - /// Returns the intersection of two narrowing constraints. The result contains the predicates - /// that appear in both inputs. - pub(crate) fn intersect_constraints( - &mut self, - a: ScopedNarrowingConstraint, - b: ScopedNarrowingConstraint, - ) -> ScopedNarrowingConstraint { - self.lists.intersect(a, b) - } -} - -// Iteration -// --------- - -pub(crate) type NarrowingConstraintsIterator<'a> = - std::iter::Copied>; - -impl NarrowingConstraints { - /// Iterates over the predicates in a narrowing constraint. 
- pub(crate) fn iter_predicates( - &self, - set: ScopedNarrowingConstraint, - ) -> NarrowingConstraintsIterator<'_> { - self.lists.iter_set_reverse(set).copied() - } -} - -// Test support -// ------------ - -#[cfg(test)] -mod tests { - use super::*; - - impl ScopedNarrowingConstraintPredicate { - pub(crate) fn as_u32(self) -> u32 { - self.0.as_u32() - } - } - - impl NarrowingConstraintsBuilder { - pub(crate) fn iter_predicates( - &self, - set: ScopedNarrowingConstraint, - ) -> NarrowingConstraintsIterator<'_> { - self.lists.iter_set_reverse(set).copied() - } - } -} diff --git a/crates/ty_python_semantic/src/semantic_index/predicate.rs b/crates/ty_python_semantic/src/semantic_index/predicate.rs index abefcc34b46b8..cb0519e6ca674 100644 --- a/crates/ty_python_semantic/src/semantic_index/predicate.rs +++ b/crates/ty_python_semantic/src/semantic_index/predicate.rs @@ -33,11 +33,6 @@ impl ScopedPredicateId { fn is_terminal(self) -> bool { self >= Self::SMALLEST_TERMINAL } - - #[cfg(test)] - pub(crate) fn as_u32(self) -> u32 { - self.0 - } } impl Idx for ScopedPredicateId { diff --git a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs index 74923ad97ae38..5761e95249be1 100644 --- a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs @@ -202,14 +202,15 @@ use crate::Db; use crate::dunder_all::dunder_all_names; use crate::place::{RequiresExplicitReExport, imported_symbol}; use crate::rank::RankBitBox; +use crate::semantic_index::place::ScopedPlaceId; use crate::semantic_index::place_table; use crate::semantic_index::predicate::{ CallableAndCallExpr, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, Predicates, ScopedPredicateId, }; use crate::types::{ - CallableTypes, IntersectionBuilder, Truthiness, Type, TypeContext, UnionBuilder, UnionType, - infer_expression_type, + CallableTypes, IntersectionBuilder, NarrowingConstraint, Truthiness, Type, TypeContext, + UnionBuilder, UnionType, infer_expression_type, infer_narrowing_constraint, }; /// A ternary formula that defines under what conditions a binding is visible. (A ternary formula @@ -734,7 +735,186 @@ impl ReachabilityConstraintsBuilder { } } +/// AND a new optional narrowing constraint with an accumulated one. +fn accumulate_constraint<'db>( + db: &'db dyn Db, + accumulated: Option>, + new: Option>, +) -> Option> { + match (accumulated, new) { + (Some(acc), Some(new_c)) => Some(new_c.merge_constraint_and(acc, db)), + (None, Some(new_c)) => Some(new_c), + (Some(acc), None) => Some(acc), + (None, None) => None, + } +} + impl ReachabilityConstraints { + /// Look up an interior node by its constraint ID. + fn get_interior_node(&self, id: ScopedReachabilityConstraintId) -> InteriorNode { + debug_assert!(!id.is_terminal()); + let raw_index = id.as_u32() as usize; + debug_assert!( + self.used_indices.get_bit(raw_index).unwrap_or(false), + "all used reachability constraints should have been marked as used", + ); + let index = self.used_indices.rank(raw_index) as usize; + self.used_interiors[index] + } + + /// Narrow a type by walking a TDD narrowing constraint. + /// + /// The TDD represents a ternary formula over predicates that encodes which predicates + /// hold along a particular control flow path. We walk from root to leaves, accumulating + /// narrowing constraints. 
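+ ///
+ /// For example, if the constraint is a single atom `isinstance(x, A)`, the only possible path is the true edge, so the result is the base type narrowed by `A`; if the constraint has simplified to `ALWAYS_TRUE` (e.g. after a complete if/else), the base type is returned unchanged.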
+ /// + /// At each interior node, we branch based on whether the predicate is true or false: + /// - True branch: apply positive narrowing from the predicate + /// - False branch: apply negative narrowing from the predicate + /// + /// The "ambiguous" branch in the TDD is not followed for narrowing purposes, because + /// narrowing constraints record which predicates hold along the control flow path. + /// The predicates may be statically ambiguous (we can't determine their truthiness + /// at analysis time), but they still hold dynamically at runtime and should be used + /// for narrowing. + /// + /// At leaves: + /// - `ALWAYS_TRUE` or `AMBIGUOUS`: apply all accumulated narrowing to the base type + /// - `ALWAYS_FALSE`: this path is impossible → Never + /// + /// The final result is the union of all path results. + pub(crate) fn narrow_by_constraint<'db>( + &self, + db: &'db dyn Db, + predicates: &Predicates<'db>, + id: ScopedReachabilityConstraintId, + base_ty: Type<'db>, + place: ScopedPlaceId, + ) -> Type<'db> { + self.narrow_by_constraint_inner(db, predicates, id, base_ty, place, None) + } + + /// Inner recursive helper that accumulates narrowing constraints along each TDD path. + fn narrow_by_constraint_inner<'db>( + &self, + db: &'db dyn Db, + predicates: &Predicates<'db>, + id: ScopedReachabilityConstraintId, + base_ty: Type<'db>, + place: ScopedPlaceId, + accumulated: Option>, + ) -> Type<'db> { + match id { + ALWAYS_TRUE | AMBIGUOUS => { + // Apply all accumulated narrowing constraints to the base type + match accumulated { + Some(constraint) => NarrowingConstraint::intersection(base_ty) + .merge_constraint_and(constraint, db) + .evaluate_constraint_type(db), + None => base_ty, + } + } + ALWAYS_FALSE => Type::Never, + _ => { + let node = self.get_interior_node(id); + let predicate = predicates[node.atom]; + + // `ReturnsNever` predicates don't narrow any variable; they only + // affect reachability. Evaluate the predicate to determine which + // path(s) are reachable, rather than walking both branches. + // `ReturnsNever` always evaluates to `AlwaysTrue` or `AlwaysFalse`, + // never `Ambiguous`. + if matches!(predicate.node, PredicateNode::ReturnsNever(_)) { + return match Self::analyze_single(db, &predicate) { + Truthiness::AlwaysTrue => self.narrow_by_constraint_inner( + db, + predicates, + node.if_true, + base_ty, + place, + accumulated, + ), + Truthiness::AlwaysFalse => self.narrow_by_constraint_inner( + db, + predicates, + node.if_false, + base_ty, + place, + accumulated, + ), + Truthiness::Ambiguous => { + unreachable!("ReturnsNever predicates should never be Ambiguous") + } + }; + } + + // Check if this predicate narrows the variable we're interested in. + let pos_constraint = infer_narrowing_constraint(db, predicate, place); + + // If the true branch is statically unreachable, skip it entirely. + if node.if_true == ALWAYS_FALSE { + let neg_predicate = Predicate { + node: predicate.node, + is_positive: !predicate.is_positive, + }; + let neg_constraint = infer_narrowing_constraint(db, neg_predicate, place); + let false_accumulated = accumulate_constraint(db, accumulated, neg_constraint); + return self.narrow_by_constraint_inner( + db, + predicates, + node.if_false, + base_ty, + place, + false_accumulated, + ); + } + + // If the false branch is statically unreachable, skip it entirely. 
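+ // (Mirrors the `if_true` case above: following the impossible path would only add `Never` to the union.)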
+ if node.if_false == ALWAYS_FALSE { + let true_accumulated = accumulate_constraint(db, accumulated, pos_constraint); + return self.narrow_by_constraint_inner( + db, + predicates, + node.if_true, + base_ty, + place, + true_accumulated, + ); + } + + // True branch: predicate holds → accumulate positive narrowing + let true_accumulated = + accumulate_constraint(db, accumulated.clone(), pos_constraint); + let true_ty = self.narrow_by_constraint_inner( + db, + predicates, + node.if_true, + base_ty, + place, + true_accumulated, + ); + + // False branch: predicate doesn't hold → accumulate negative narrowing + let neg_predicate = Predicate { + node: predicate.node, + is_positive: !predicate.is_positive, + }; + let neg_constraint = infer_narrowing_constraint(db, neg_predicate, place); + let false_accumulated = accumulate_constraint(db, accumulated, neg_constraint); + let false_ty = self.narrow_by_constraint_inner( + db, + predicates, + node.if_false, + base_ty, + place, + false_accumulated, + ); + + UnionType::from_elements(db, [true_ty, false_ty]) + } + } + } + /// Analyze the statically known reachability for a given constraint. pub(crate) fn evaluate<'db>( &self, diff --git a/crates/ty_python_semantic/src/semantic_index/use_def.rs b/crates/ty_python_semantic/src/semantic_index/use_def.rs index e32687f52d600..4a32870bafa4a 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def.rs @@ -248,13 +248,10 @@ use crate::place::BoundnessAnalysis; use crate::semantic_index::ast_ids::ScopedUseId; use crate::semantic_index::definition::{Definition, DefinitionState}; use crate::semantic_index::member::ScopedMemberId; -use crate::semantic_index::narrowing_constraints::{ - ConstraintKey, NarrowingConstraints, NarrowingConstraintsBuilder, NarrowingConstraintsIterator, - ScopedNarrowingConstraint, -}; +use crate::semantic_index::narrowing_constraints::{ConstraintKey, ScopedNarrowingConstraint}; use crate::semantic_index::place::{PlaceExprRef, ScopedPlaceId}; use crate::semantic_index::predicate::{ - Predicate, PredicateOrLiteral, Predicates, PredicatesBuilder, ScopedPredicateId, + PredicateOrLiteral, Predicates, PredicatesBuilder, ScopedPredicateId, }; use crate::semantic_index::reachability_constraints::{ ReachabilityConstraints, ReachabilityConstraintsBuilder, ScopedReachabilityConstraintId, @@ -266,9 +263,7 @@ use crate::semantic_index::use_def::place_state::{ LiveDeclarationsIterator, PlaceState, PreviousDefinitions, ScopedDefinitionId, }; use crate::semantic_index::{EnclosingSnapshotResult, SemanticIndex}; -use crate::types::{ - NarrowingConstraint, PossiblyNarrowedPlaces, Truthiness, Type, infer_narrowing_constraint, -}; +use crate::types::{PossiblyNarrowedPlaces, Truthiness, Type}; mod place_state; @@ -282,9 +277,6 @@ pub(crate) struct UseDefMap<'db> { /// Array of predicates in this scope. predicates: Predicates<'db>, - /// Array of narrowing constraints in this scope. - narrowing_constraints: NarrowingConstraints, - /// Array of reachability constraints in this scope. 
reachability_constraints: ReachabilityConstraints, @@ -351,7 +343,7 @@ pub(crate) struct UseDefMap<'db> { } pub(crate) enum ApplicableConstraints<'map, 'db> { - UnboundBinding(ConstraintsIterator<'map, 'db>), + UnboundBinding(NarrowingEvaluator<'map, 'db>), ConstrainedBindings(BindingWithConstraintsIterator<'map, 'db>), } @@ -375,9 +367,10 @@ impl<'db> UseDefMap<'db> { ) -> ApplicableConstraints<'_, 'db> { match constraint_key { ConstraintKey::NarrowingConstraint(constraint) => { - ApplicableConstraints::UnboundBinding(ConstraintsIterator { + ApplicableConstraints::UnboundBinding(NarrowingEvaluator { + constraint, predicates: &self.predicates, - constraint_ids: self.narrowing_constraints.iter_predicates(constraint), + reachability_constraints: &self.reachability_constraints, }) } ConstraintKey::NestedScope(nested_scope) => { @@ -644,7 +637,6 @@ impl<'db> UseDefMap<'db> { BindingWithConstraintsIterator { all_definitions: &self.all_definitions, predicates: &self.predicates, - narrowing_constraints: &self.narrowing_constraints, reachability_constraints: &self.reachability_constraints, boundness_analysis, inner: bindings.iter(), @@ -700,7 +692,6 @@ type EnclosingSnapshots = IndexVec pub(crate) struct BindingWithConstraintsIterator<'map, 'db> { all_definitions: &'map IndexVec>, pub(crate) predicates: &'map Predicates<'db>, - pub(crate) narrowing_constraints: &'map NarrowingConstraints, pub(crate) reachability_constraints: &'map ReachabilityConstraints, pub(crate) boundness_analysis: BoundnessAnalysis, inner: LiveBindingsIterator<'map>, @@ -711,16 +702,16 @@ impl<'map, 'db> Iterator for BindingWithConstraintsIterator<'map, 'db> { fn next(&mut self) -> Option { let predicates = self.predicates; - let narrowing_constraints = self.narrowing_constraints; + let reachability_constraints = self.reachability_constraints; self.inner .next() .map(|live_binding| BindingWithConstraints { binding: self.all_definitions[live_binding.binding], - narrowing_constraint: ConstraintsIterator { + narrowing_constraint: NarrowingEvaluator { + constraint: live_binding.narrowing_constraint, predicates, - constraint_ids: narrowing_constraints - .iter_predicates(live_binding.narrowing_constraint), + reachability_constraints, }, reachability_constraint: live_binding.reachability_constraint, }) @@ -731,50 +722,30 @@ impl std::iter::FusedIterator for BindingWithConstraintsIterator<'_, '_> {} pub(crate) struct BindingWithConstraints<'map, 'db> { pub(crate) binding: DefinitionState<'db>, - pub(crate) narrowing_constraint: ConstraintsIterator<'map, 'db>, + pub(crate) narrowing_constraint: NarrowingEvaluator<'map, 'db>, pub(crate) reachability_constraint: ScopedReachabilityConstraintId, } -pub(crate) struct ConstraintsIterator<'map, 'db> { +pub(crate) struct NarrowingEvaluator<'map, 'db> { + pub(crate) constraint: ScopedNarrowingConstraint, predicates: &'map Predicates<'db>, - constraint_ids: NarrowingConstraintsIterator<'map>, -} - -impl<'db> Iterator for ConstraintsIterator<'_, 'db> { - type Item = Predicate<'db>; - - fn next(&mut self) -> Option { - self.constraint_ids - .next() - .map(|narrowing_constraint| self.predicates[narrowing_constraint.predicate()]) - } + reachability_constraints: &'map ReachabilityConstraints, } -impl std::iter::FusedIterator for ConstraintsIterator<'_, '_> {} - -impl<'db> ConstraintsIterator<'_, 'db> { +impl<'db> NarrowingEvaluator<'_, 'db> { pub(crate) fn narrow( self, db: &'db dyn crate::Db, base_ty: Type<'db>, place: ScopedPlaceId, ) -> Type<'db> { - // Constraints are in reverse-source order. 
Due to TypeGuard semantics - // constraint AND is non-commutative and so we _must_ apply in - // source order. - // - // Fortunately, constraint AND is still associative, so we can still iterate left-to-right - // and accumulate rightward. - self.filter_map(|constraint| infer_narrowing_constraint(db, constraint, place)) - .reduce(|acc, constraint| { - // See above---note the reverse application - constraint.merge_constraint_and(acc, db) - }) - .map_or(base_ty, |constraint| { - NarrowingConstraint::intersection(base_ty) - .merge_constraint_and(constraint, db) - .evaluate_constraint_type(db) - }) + self.reachability_constraints.narrow_by_constraint( + db, + self.predicates, + self.constraint, + base_ty, + place, + ) } } @@ -842,9 +813,6 @@ pub(super) struct UseDefMapBuilder<'db> { /// Builder of predicates. pub(super) predicates: PredicatesBuilder<'db>, - /// Builder of narrowing constraints. - pub(super) narrowing_constraints: NarrowingConstraintsBuilder, - /// Builder of reachability constraints. pub(super) reachability_constraints: ReachabilityConstraintsBuilder, @@ -887,7 +855,6 @@ impl<'db> UseDefMapBuilder<'db> { Self { all_definitions: IndexVec::from_iter([DefinitionState::Undefined]), predicates: PredicatesBuilder::default(), - narrowing_constraints: NarrowingConstraintsBuilder::default(), reachability_constraints: ReachabilityConstraintsBuilder::default(), bindings_by_use: IndexVec::new(), reachability: ScopedReachabilityConstraintId::ALWAYS_TRUE, @@ -1017,22 +984,53 @@ impl<'db> UseDefMapBuilder<'db> { return; } - let narrowing_constraint = predicate.into(); + let atom = self.reachability_constraints.add_atom(predicate); + self.record_narrowing_constraint_node_for_places(atom, places); + } + + /// Records a negated narrowing constraint for only the specified places. + /// + /// Uses TDD-level negation (`add_not_constraint`) rather than creating a new predicate atom + /// for the negated predicate. This ensures that `atom(P) OR NOT(atom(P))` simplifies to + /// `ALWAYS_TRUE` in the TDD, so narrowing is correctly cancelled out after complete + /// if/else blocks. + pub(super) fn record_negated_narrowing_constraint_for_places( + &mut self, + predicate: ScopedPredicateId, + places: &PossiblyNarrowedPlaces, + ) { + if predicate == ScopedPredicateId::ALWAYS_TRUE + || predicate == ScopedPredicateId::ALWAYS_FALSE + { + return; + } + + let atom = self.reachability_constraints.add_atom(predicate); + let negated = self.reachability_constraints.add_not_constraint(atom); + self.record_narrowing_constraint_node_for_places(negated, places); + } + + /// Records a TDD narrowing constraint node for the specified places. 
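+ ///
+ /// The node is ANDed (via `add_and_constraint`) into the narrowing constraint of
+ /// every live binding of each given place; bindings of places not listed are left
+ /// untouched.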
+ fn record_narrowing_constraint_node_for_places( + &mut self, + constraint: ScopedNarrowingConstraint, + places: &PossiblyNarrowedPlaces, + ) { for place in places { match place { ScopedPlaceId::Symbol(symbol_id) => { if let Some(state) = self.symbol_states.get_mut(*symbol_id) { state.record_narrowing_constraint( - &mut self.narrowing_constraints, - narrowing_constraint, + &mut self.reachability_constraints, + constraint, ); } } ScopedPlaceId::Member(member_id) => { if let Some(state) = self.member_states.get_mut(*member_id) { state.record_narrowing_constraint( - &mut self.narrowing_constraints, - narrowing_constraint, + &mut self.reachability_constraints, + constraint, ); } } @@ -1112,11 +1110,7 @@ impl<'db> UseDefMapBuilder<'db> { negated_reachability_id, ); - self.symbol_states[symbol].merge( - post_definition_state, - &mut self.narrowing_constraints, - &mut self.reachability_constraints, - ); + self.symbol_states[symbol].merge(post_definition_state, &mut self.reachability_constraints); // And similarly for all associated members: for (member_id, pre_definition_member_state) in pre_definition.associated_member_states { @@ -1135,11 +1129,26 @@ impl<'db> UseDefMapBuilder<'db> { negated_reachability_id, ); - self.member_states[member_id].merge( - post_definition_state, - &mut self.narrowing_constraints, - &mut self.reachability_constraints, - ); + self.member_states[member_id] + .merge(post_definition_state, &mut self.reachability_constraints); + } + } + + /// Records a narrowing constraint for all places in the current scope. + /// + /// This is used to gate narrowing by `ReturnsNever` constraints: when a branch contains + /// a call to a `NoReturn` function, all narrowing in that branch should be conditional + /// on the call actually returning `Never`. + pub(super) fn record_narrowing_constraint_for_all_places( + &mut self, + constraint: ScopedNarrowingConstraint, + ) { + for state in &mut self.symbol_states { + state.record_narrowing_constraint(&mut self.reachability_constraints, constraint); + } + + for state in &mut self.member_states { + state.record_narrowing_constraint(&mut self.reachability_constraints, constraint); } } @@ -1307,12 +1316,11 @@ impl<'db> UseDefMapBuilder<'db> { let new_symbol_state = &self.symbol_states[enclosing_symbol]; bindings.merge( new_symbol_state.bindings().clone(), - &mut self.narrowing_constraints, &mut self.reachability_constraints, ); } Some(EnclosingSnapshot::Constraint(constraint)) => { - *constraint = ScopedNarrowingConstraint::empty(); + *constraint = ScopedNarrowingConstraint::ALWAYS_TRUE; } None => {} } @@ -1379,15 +1387,10 @@ impl<'db> UseDefMapBuilder<'db> { let mut snapshot_definitions_iter = snapshot.symbol_states.into_iter(); for current in &mut self.symbol_states { if let Some(snapshot) = snapshot_definitions_iter.next() { - current.merge( - snapshot, - &mut self.narrowing_constraints, - &mut self.reachability_constraints, - ); + current.merge(snapshot, &mut self.reachability_constraints); } else { current.merge( PlaceState::undefined(snapshot.reachability), - &mut self.narrowing_constraints, &mut self.reachability_constraints, ); // Place not present in snapshot, so it's unbound/undeclared from that path. 
@@ -1397,15 +1400,10 @@ impl<'db> UseDefMapBuilder<'db> { let mut snapshot_definitions_iter = snapshot.member_states.into_iter(); for current in &mut self.member_states { if let Some(snapshot) = snapshot_definitions_iter.next() { - current.merge( - snapshot, - &mut self.narrowing_constraints, - &mut self.reachability_constraints, - ); + current.merge(snapshot, &mut self.reachability_constraints); } else { current.merge( PlaceState::undefined(snapshot.reachability), - &mut self.narrowing_constraints, &mut self.reachability_constraints, ); // Place not present in snapshot, so it's unbound/undeclared from that path. @@ -1477,7 +1475,6 @@ impl<'db> UseDefMapBuilder<'db> { UseDefMap { all_definitions: self.all_definitions, predicates: self.predicates.build(), - narrowing_constraints: self.narrowing_constraints.build(), reachability_constraints: self.reachability_constraints.build(), bindings_by_use: self.bindings_by_use, node_reachability: self.node_reachability, diff --git a/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs b/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs index 4695bda41d504..033f0aa5426d3 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs @@ -46,9 +46,7 @@ use itertools::{EitherOrBoth, Itertools}; use ruff_index::newtype_index; use smallvec::{SmallVec, smallvec}; -use crate::semantic_index::narrowing_constraints::{ - NarrowingConstraintsBuilder, ScopedNarrowingConstraint, ScopedNarrowingConstraintPredicate, -}; +use crate::semantic_index::narrowing_constraints::ScopedNarrowingConstraint; use crate::semantic_index::reachability_constraints::{ ReachabilityConstraintsBuilder, ScopedReachabilityConstraintId, }; @@ -195,7 +193,9 @@ pub(super) enum EnclosingSnapshot { impl EnclosingSnapshot { pub(super) fn finish(&mut self, reachability_constraints: &mut ReachabilityConstraintsBuilder) { match self { - Self::Constraint(_) => {} + Self::Constraint(constraint) => { + reachability_constraints.mark_used(*constraint); + } Self::Bindings(bindings) => { bindings.finish(reachability_constraints); } @@ -226,6 +226,7 @@ impl Bindings { self.live_bindings.shrink_to_fit(); for binding in &self.live_bindings { reachability_constraints.mark_used(binding.reachability_constraint); + reachability_constraints.mark_used(binding.narrowing_constraint); } } } @@ -244,7 +245,7 @@ impl Bindings { pub(super) fn unbound(reachability_constraint: ScopedReachabilityConstraintId) -> Self { let initial_binding = LiveBinding { binding: ScopedDefinitionId::UNBOUND, - narrowing_constraint: ScopedNarrowingConstraint::empty(), + narrowing_constraint: ScopedNarrowingConstraint::ALWAYS_TRUE, reachability_constraint, }; Self { @@ -274,7 +275,7 @@ impl Bindings { } self.live_bindings.push(LiveBinding { binding, - narrowing_constraint: ScopedNarrowingConstraint::empty(), + narrowing_constraint: ScopedNarrowingConstraint::ALWAYS_TRUE, reachability_constraint, }); } @@ -282,12 +283,12 @@ impl Bindings { /// Add given constraint to all live bindings. 
pub(super) fn record_narrowing_constraint( &mut self, - narrowing_constraints: &mut NarrowingConstraintsBuilder, - predicate: ScopedNarrowingConstraintPredicate, + reachability_constraints: &mut ReachabilityConstraintsBuilder, + constraint: ScopedNarrowingConstraint, ) { for binding in &mut self.live_bindings { - binding.narrowing_constraint = narrowing_constraints - .add_predicate_to_constraint(binding.narrowing_constraint, predicate); + binding.narrowing_constraint = reachability_constraints + .add_and_constraint(binding.narrowing_constraint, constraint); } } @@ -311,7 +312,6 @@ impl Bindings { pub(super) fn merge( &mut self, b: Self, - narrowing_constraints: &mut NarrowingConstraintsBuilder, reachability_constraints: &mut ReachabilityConstraintsBuilder, ) { let a = std::mem::take(self); @@ -321,26 +321,25 @@ impl Bindings { .zip(b.unbound_narrowing_constraint) { self.unbound_narrowing_constraint = - Some(narrowing_constraints.intersect_constraints(a, b)); + Some(reachability_constraints.add_or_constraint(a, b)); } // Invariant: merge_join_by consumes the two iterators in sorted order, which ensures that // the merged `live_bindings` vec remains sorted. If a definition is found in both `a` and - // `b`, we compose the constraints from the two paths in an appropriate way (intersection - // for narrowing constraints; ternary OR for reachability constraints). If a definition is - // found in only one path, it is used as-is. + // `b`, we compose the constraints from the two paths using ternary OR for both narrowing + // and reachability constraints. If a definition is found in only one path, it is used + // as-is. let a = a.live_bindings.into_iter(); let b = b.live_bindings.into_iter(); for zipped in a.merge_join_by(b, |a, b| a.binding.cmp(&b.binding)) { match zipped { EitherOrBoth::Both(a, b) => { - // If the same definition is visible through both paths, any constraint - // that applies on only one path is irrelevant to the resulting type from - // unioning the two paths, so we intersect the constraints. - let narrowing_constraint = narrowing_constraints - .intersect_constraints(a.narrowing_constraint, b.narrowing_constraint); + // If the same definition is visible through both paths, we OR the narrowing + // constraints: the type should be narrowed by whichever path was taken. + let narrowing_constraint = reachability_constraints + .add_or_constraint(a.narrowing_constraint, b.narrowing_constraint); - // For reachability constraints, we merge them using a ternary OR operation: + // For reachability constraints, we also merge using a ternary OR operation: let reachability_constraint = reachability_constraints .add_or_constraint(a.reachability_constraint, b.reachability_constraint); @@ -395,11 +394,11 @@ impl PlaceState { /// Add given constraint to all live bindings. pub(super) fn record_narrowing_constraint( &mut self, - narrowing_constraints: &mut NarrowingConstraintsBuilder, - constraint: ScopedNarrowingConstraintPredicate, + reachability_constraints: &mut ReachabilityConstraintsBuilder, + constraint: ScopedNarrowingConstraint, ) { self.bindings - .record_narrowing_constraint(narrowing_constraints, constraint); + .record_narrowing_constraint(reachability_constraints, constraint); } /// Add given reachability constraint to all live bindings. 
@@ -431,11 +430,9 @@ impl PlaceState { pub(super) fn merge( &mut self, b: PlaceState, - narrowing_constraints: &mut NarrowingConstraintsBuilder, reachability_constraints: &mut ReachabilityConstraintsBuilder, ) { - self.bindings - .merge(b.bindings, narrowing_constraints, reachability_constraints); + self.bindings.merge(b.bindings, reachability_constraints); self.declarations .merge(b.declarations, reachability_constraints); } @@ -462,29 +459,17 @@ mod tests { use crate::semantic_index::predicate::ScopedPredicateId; #[track_caller] - fn assert_bindings( - narrowing_constraints: &NarrowingConstraintsBuilder, - place: &PlaceState, - expected: &[&str], - ) { - let actual = place + fn assert_bindings(place: &PlaceState, expected: &[(u32, ScopedNarrowingConstraint)]) { + let actual: Vec<(u32, ScopedNarrowingConstraint)> = place .bindings() .iter() .map(|live_binding| { - let def_id = live_binding.binding; - let def = if def_id == ScopedDefinitionId::UNBOUND { - "unbound".into() - } else { - def_id.as_u32().to_string() - }; - let predicates = narrowing_constraints - .iter_predicates(live_binding.narrowing_constraint) - .map(|idx| idx.as_u32().to_string()) - .collect::>() - .join(", "); - format!("{def}<{predicates}>") + ( + live_binding.binding.as_u32(), + live_binding.narrowing_constraint, + ) }) - .collect::>(); + .collect(); assert_eq!(actual, expected); } @@ -511,15 +496,13 @@ mod tests { #[test] fn unbound() { - let narrowing_constraints = NarrowingConstraintsBuilder::default(); let sym = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); - assert_bindings(&narrowing_constraints, &sym, &["unbound<>"]); + assert_bindings(&sym, &[(0, ScopedNarrowingConstraint::ALWAYS_TRUE)]); } #[test] fn with() { - let narrowing_constraints = NarrowingConstraintsBuilder::default(); let mut sym = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); sym.record_binding( ScopedDefinitionId::from_u32(1), @@ -528,12 +511,12 @@ mod tests { true, ); - assert_bindings(&narrowing_constraints, &sym, &["1<>"]); + assert_bindings(&sym, &[(1, ScopedNarrowingConstraint::ALWAYS_TRUE)]); } #[test] fn record_constraint() { - let mut narrowing_constraints = NarrowingConstraintsBuilder::default(); + let mut reachability_constraints = ReachabilityConstraintsBuilder::default(); let mut sym = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); sym.record_binding( ScopedDefinitionId::from_u32(1), @@ -541,15 +524,14 @@ mod tests { false, true, ); - let predicate = ScopedPredicateId::new(0).into(); - sym.record_narrowing_constraint(&mut narrowing_constraints, predicate); + let atom = reachability_constraints.add_atom(ScopedPredicateId::new(0)); + sym.record_narrowing_constraint(&mut reachability_constraints, atom); - assert_bindings(&narrowing_constraints, &sym, &["1<0>"]); + assert_bindings(&sym, &[(1, atom)]); } #[test] fn merge() { - let mut narrowing_constraints = NarrowingConstraintsBuilder::default(); let mut reachability_constraints = ReachabilityConstraintsBuilder::default(); // merging the same definition with the same constraint keeps the constraint @@ -560,8 +542,8 @@ mod tests { false, true, ); - let predicate = ScopedPredicateId::new(0).into(); - sym1a.record_narrowing_constraint(&mut narrowing_constraints, predicate); + let atom0 = reachability_constraints.add_atom(ScopedPredicateId::new(0)); + sym1a.record_narrowing_constraint(&mut reachability_constraints, atom0); let mut sym1b = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); sym1b.record_binding( @@ 
-570,18 +552,14 @@ mod tests { false, true, ); - let predicate = ScopedPredicateId::new(0).into(); - sym1b.record_narrowing_constraint(&mut narrowing_constraints, predicate); + sym1b.record_narrowing_constraint(&mut reachability_constraints, atom0); - sym1a.merge( - sym1b, - &mut narrowing_constraints, - &mut reachability_constraints, - ); + sym1a.merge(sym1b, &mut reachability_constraints); let mut sym1 = sym1a; - assert_bindings(&narrowing_constraints, &sym1, &["1<0>"]); + // Same constraint on both sides → OR(atom0, atom0) = atom0 + assert_bindings(&sym1, &[(1, atom0)]); - // merging the same definition with differing constraints drops all constraints + // merging the same definition with differing constraints produces OR (not empty) let mut sym2a = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); sym2a.record_binding( ScopedDefinitionId::from_u32(2), @@ -589,8 +567,8 @@ mod tests { false, true, ); - let predicate = ScopedPredicateId::new(1).into(); - sym2a.record_narrowing_constraint(&mut narrowing_constraints, predicate); + let atom1 = reachability_constraints.add_atom(ScopedPredicateId::new(1)); + sym2a.record_narrowing_constraint(&mut reachability_constraints, atom1); let mut sym1b = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); sym1b.record_binding( @@ -599,16 +577,17 @@ mod tests { false, true, ); - let predicate = ScopedPredicateId::new(2).into(); - sym1b.record_narrowing_constraint(&mut narrowing_constraints, predicate); + let atom2 = reachability_constraints.add_atom(ScopedPredicateId::new(2)); + sym1b.record_narrowing_constraint(&mut reachability_constraints, atom2); - sym2a.merge( - sym1b, - &mut narrowing_constraints, - &mut reachability_constraints, - ); + sym2a.merge(sym1b, &mut reachability_constraints); let sym2 = sym2a; - assert_bindings(&narrowing_constraints, &sym2, &["2<>"]); + // Different constraints: OR(atom1, atom2) produces a new TDD node (not a terminal) + let merged_constraint = sym2.bindings().iter().next().unwrap().narrowing_constraint; + assert_ne!(merged_constraint, ScopedNarrowingConstraint::ALWAYS_TRUE); + assert_ne!(merged_constraint, ScopedNarrowingConstraint::ALWAYS_FALSE); + assert_ne!(merged_constraint, atom1); + assert_ne!(merged_constraint, atom2); // merging a constrained definition with unbound keeps both let mut sym3a = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); @@ -618,27 +597,37 @@ mod tests { false, true, ); - let predicate = ScopedPredicateId::new(3).into(); - sym3a.record_narrowing_constraint(&mut narrowing_constraints, predicate); + let atom3 = reachability_constraints.add_atom(ScopedPredicateId::new(3)); + sym3a.record_narrowing_constraint(&mut reachability_constraints, atom3); let sym2b = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); - sym3a.merge( - sym2b, - &mut narrowing_constraints, - &mut reachability_constraints, - ); + sym3a.merge(sym2b, &mut reachability_constraints); let sym3 = sym3a; - assert_bindings(&narrowing_constraints, &sym3, &["unbound<>", "3<3>"]); + let bindings: Vec<_> = sym3 + .bindings() + .iter() + .map(|b| (b.binding.as_u32(), b.narrowing_constraint)) + .collect(); + assert_eq!(bindings.len(), 2); + assert_eq!(bindings[0].0, 0); // unbound + assert_eq!(bindings[1].0, 3); + assert_eq!(bindings[1].1, atom3); // merging different definitions keeps them each with their existing constraints - sym1.merge( - sym3, - &mut narrowing_constraints, - &mut reachability_constraints, - ); + sym1.merge(sym3, &mut reachability_constraints); 
let sym = sym1; - assert_bindings(&narrowing_constraints, &sym, &["unbound<>", "1<0>", "3<3>"]); + let bindings: Vec<_> = sym + .bindings() + .iter() + .map(|b| (b.binding.as_u32(), b.narrowing_constraint)) + .collect(); + assert_eq!(bindings.len(), 3); + assert_eq!(bindings[0].0, 0); // unbound + assert_eq!(bindings[1].0, 1); + assert_eq!(bindings[1].1, atom0); + assert_eq!(bindings[2].0, 3); + assert_eq!(bindings[2].1, atom3); } #[test] @@ -676,7 +665,6 @@ mod tests { #[test] fn record_declaration_merge() { - let mut narrowing_constraints = NarrowingConstraintsBuilder::default(); let mut reachability_constraints = ReachabilityConstraintsBuilder::default(); let mut sym = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); sym.record_declaration( @@ -690,18 +678,13 @@ mod tests { ScopedReachabilityConstraintId::ALWAYS_TRUE, ); - sym.merge( - sym2, - &mut narrowing_constraints, - &mut reachability_constraints, - ); + sym.merge(sym2, &mut reachability_constraints); assert_declarations(&sym, &["1", "2"]); } #[test] fn record_declaration_merge_partial_undeclared() { - let mut narrowing_constraints = NarrowingConstraintsBuilder::default(); let mut reachability_constraints = ReachabilityConstraintsBuilder::default(); let mut sym = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); sym.record_declaration( @@ -711,11 +694,7 @@ mod tests { let sym2 = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); - sym.merge( - sym2, - &mut narrowing_constraints, - &mut reachability_constraints, - ); + sym.merge(sym2, &mut reachability_constraints); assert_declarations(&sym, &["undeclared", "1"]); }