diff --git a/crates/ty_python_semantic/resources/mdtest/loops/for.md b/crates/ty_python_semantic/resources/mdtest/loops/for.md index eba60a99fca4c..4996029c2975b 100644 --- a/crates/ty_python_semantic/resources/mdtest/loops/for.md +++ b/crates/ty_python_semantic/resources/mdtest/loops/for.md @@ -1051,3 +1051,276 @@ def _(value: list[Not[str]]): for x in value: reveal_type(x) # revealed: ~str ``` + +## Walrus definitions in the iterator expression are always evaluated + +```py +for _ in (x := range(0)): + pass +reveal_type(x) # revealed: range +``` + +## Cyclic control flow + +### Basic + +```py +i = 0 +reveal_type(i) # revealed: Literal[0] +for _ in range(1_000_000): + i += 1 + reveal_type(i) # revealed: int +reveal_type(i) # revealed: int +``` + +### A binding that didn't exist before the loop started + +```py +i = 0 +for _ in range(1_000_000): + if i > 0: + loop_only += 1 # error: [possibly-unresolved-reference] + if i == 0: + loop_only = 0 + i += 1 +# error: [possibly-unresolved-reference] +reveal_type(loop_only) # revealed: int +``` + +### Nested loops with `break` and `continue` + +```py +def random() -> bool: + return False + +x = "A" +for _ in range(1_000_000): + reveal_type(x) # revealed: Literal["A", "D"] + for _ in range(1_000_000): + # The "C" binding isn't visible here. It breaks this inner loop, and it always gets + # overwritten before the end of the outer loop. + reveal_type(x) # revealed: Literal["A", "D", "B"] + if random(): + x = "B" + continue + else: + x = "C" + break + reveal_type(x) # revealed: Never + # We don't know whether a `for` loop will execute its body at all, so "A" is still visible here. + # Similarly, we don't know when the loop will terminate, so "B" is also visible here despite the + # `continue` above. + reveal_type(x) # revealed: Literal["A", "D", "B", "C"] + if random(): + x = "D" + continue + else: + x = "E" + break + reveal_type(x) # revealed: Never +reveal_type(x) # revealed: Literal["A", "D", "E"] +``` + +### Walrus operator assignments are visible via loopback + +```py +for _ in range(1_000_000): + # error: [possibly-unresolved-reference] + reveal_type(y) # revealed: Literal[1] + x = (y := 1) +``` + +### Loopback bindings are not visible to the walrus operator in iterable expression + +The iterable is only evaluated once, before the loop body runs. + +```py +x = "hello" +for _ in (y := x): + # This assignment is not visible when the iterable `x` is used above. + x = None +reveal_type(y) # revealed: Literal["hello"] +``` + +### "Member" (as opposed to "symbol") places are also given loopback bindings + +```py +my_dict = {} +my_dict["x"] = 0 +reveal_type(my_dict["x"]) # revealed: Literal[0] +for _ in range(1_000_000): + my_dict["x"] += 1 +reveal_type(my_dict["x"]) # revealed: int +``` + +### `del` prevents bindings from reaching the loopback + +This `x` cannot reach the use at the top of the loop: + +```py +for _ in range(1_000_000): + x # error: [unresolved-reference] + x = 42 + del x +``` + +On the other hand, if `x` is defined before the loop, the `del` makes it a +`[possibly-unresolved-reference]`: + +```py +x = 0 +for _ in range(1_000_000): + x # error: [possibly-unresolved-reference] + x = 42 + del x +``` + +### `del` in a loop makes a variable possibly-unbound after the loop + +```py +x = 0 +for _ in range(1_000_000): + # error: [possibly-unresolved-reference] + del x +# error: [possibly-unresolved-reference] +x +``` + +### Bindings in a loop are possibly-unbound after the loop + +```py +for _ in range(1_000_000): + x = 42 +# error: [possibly-unresolved-reference] +x +``` + +### Swap bindings converge normally under fixpoint iteration + +```py +x = 1 +y = 2 +for _ in range(1_000_000): + x, y = y, x + # Note that we get correct types in the "avoid oscillations" test case below, but not here. I + # believe the difference is that in this case the Salsa "cycle head" is the tuple on the RHS of + # the assignment, which triggers our recursive type handling, whereas below it's `x`. + # TODO: should be Literal[2, 1] + reveal_type(x) # revealed: Divergent + # TODO: should be Literal[1, 2] + reveal_type(y) # revealed: Divergent +``` + +### Tuple assignments are inferred correctly + +```py +x = 0 +for _ in range(1_000_000): + x, y = x + 1, None + # TODO: should be int + reveal_type(x) # revealed: Divergent +``` + +### Avoid oscillations + +We need to avoid oscillating cycles in cases like the following, where the type of one of these loop +variables also influences the static reachability of its bindings. This case was minimized from a +real crash that came up during development checking these lines of `sympy`: + + +```py +x = 1 +y = 2 +for _ in range(1_000_000): + if x: + x, y = y, x + reveal_type(x) # revealed: Literal[2, 1] + reveal_type(y) # revealed: Literal[1, 2] +``` + +### Bindings in statically unreachable branches are excluded from loopback + +```py +VAL = 1 + +x = 1 +for _ in range(1_000_000): + reveal_type(x) # revealed: Literal[1] + if VAL - 1: + x = 2 +``` + +### `Divergent` in narrowing conditions doesn't run afoul of "monotonic widening" in cycle recovery + +This test looks for a complicated inference failure case that came up during implementation. See the +`while` variant of this case in `while_loop.md` for a detailed description. + +```py +class Node: + def __init__(self, next: "Node | None" = None): + self.next: "Node | None" = next + +node = Node(Node(Node())) +for _ in range(1_000_000): + if node.next is None: + break + node = node.next +reveal_type(node) # revealed: Node +reveal_type(node.next) # revealed: Node | None +``` + +### `global` and `nonlocal` keywords in a loop + +We need to make sure that the loop header definition doesn't count as a "use" prior to the +`global`/`nonlocal` declaration, or else we'll emit a false-positive semantic syntax error. + +```py +x = 0 + +def _(): + y = 0 + def _(): + for _ in range(1_000_000): + global x + nonlocal y + x = 42 + y = 99 +``` + +On the other hand, we don't want to shadow true positives: + +```py +x = 0 + +def _(): + y = 0 + def _(): + x = 1 + y = 1 + for _ in range(1_000_000): + global x # error: [invalid-syntax] "name `x` is used prior to global declaration" + nonlocal y # error: [invalid-syntax] "name `y` is used prior to nonlocal declaration" +``` + +### Loop header definitions don't shadow member bindings + +```py +class C: + x = None + +c = C() +c.x = 0 + +for _ in range(1): + reveal_type(c.x) # revealed: Literal[0] + c = C() + break + +d = [0] +d[0] = 1 + +for _ in range(1): + reveal_type(d[0]) # revealed: Literal[1] + d = [] + break +``` diff --git a/crates/ty_python_semantic/resources/mdtest/loops/while_loop.md b/crates/ty_python_semantic/resources/mdtest/loops/while_loop.md index 41e48b1404fc4..a14f891d932f2 100644 --- a/crates/ty_python_semantic/resources/mdtest/loops/while_loop.md +++ b/crates/ty_python_semantic/resources/mdtest/loops/while_loop.md @@ -127,3 +127,474 @@ class NotBoolable: while NotBoolable(): ... ``` + +## Walrus definitions in the condition are always evaluated + +```py +while x := False: + pass +reveal_type(x) # revealed: Literal[False] +``` + +## Cyclic control flow + +### Basic + +```py +def random() -> bool: + return False + +i = 0 +reveal_type(i) # revealed: Literal[0] +while random(): + i += 1 + reveal_type(i) # revealed: int +reveal_type(i) # revealed: int +``` + +### A binding that didn't exist before the loop started + +```py +i = 0 +while i < 1_000_000: + if i > 0: + loop_only += 1 # error: [possibly-unresolved-reference] + if i == 0: + loop_only = 0 + i += 1 +# error: [possibly-unresolved-reference] +reveal_type(loop_only) # revealed: int +``` + +### A more complex example + +Here the loop condition narrows both the loop-back value and the end-of-loop value: + +```py +def random() -> bool: + return False + +x = "A" +while x != "C": + reveal_type(x) # revealed: Literal["A", "B"] + if random(): + x = "B" + else: + x = "C" + reveal_type(x) # revealed: Literal["B", "C"] +reveal_type(x) # revealed: Literal["C"] +``` + +### An even more complex example + +```py +def random() -> bool: + return False + +x = "A" +while x != "E": + reveal_type(x) # revealed: Literal["A", "C", "D"] + while x != "C": + reveal_type(x) # revealed: Literal["A", "D", "B"] + if random(): + x = "B" + else: + x = "C" + reveal_type(x) # revealed: Literal["B", "C"] + reveal_type(x) # revealed: Literal["C"] + if random(): + x = "D" + if random(): + x = "E" + reveal_type(x) # revealed: Literal["C", "D", "E"] +reveal_type(x) # revealed: Literal["E"] +``` + +### `break` and `continue` + +```py +def random() -> bool: + return False + +x = "A" +while True: + reveal_type(x) # revealed: Literal["A", "C", "D"] + while True: + reveal_type(x) # revealed: Literal["A", "C", "D", "B"] + if random(): + x = "B" + continue + else: + x = "C" + break + reveal_type(x) # revealed: Never + reveal_type(x) # revealed: Literal["C"] + if random(): + x = "D" + continue + if random(): + x = "E" + break + reveal_type(x) # revealed: Literal["C"] +reveal_type(x) # revealed: Literal["E"] +``` + +### Interaction between `break` and a narrowing condition + +Here the loop condition forces `x` to be `False` at loop exit, because there is no `break`: + +```py +def random() -> bool: + return True + +x = random() +reveal_type(x) # revealed: bool +while x: + pass +reveal_type(x) # revealed: Literal[False] +``` + +However, we can't narrow `x` like this when there's a `break` in the loop: + +```py +x = random() +while x: + if random(): + break +reveal_type(x) # revealed: bool +``` + +### Non-static loop conditions + +```py +def random() -> bool: + return False + +x = "A" +while random(): + reveal_type(x) # revealed: Literal["A", "B", "C", "D"] + x = "B" + if random(): + x = "C" + if x == "C": + continue + reveal_type(x) # revealed: Literal["B"] + while random(): + reveal_type(x) # revealed: Literal["B", "D"] + if random(): + x = "D" + continue + x = "E" + break + reveal_type(x) # revealed: Literal["B", "D", "E"] + if x == "E": + break + reveal_type(x) # revealed: Literal["B", "D"] +reveal_type(x) # revealed: Literal["A", "B", "C", "D", "E"] +``` + +### Functions and classes defined in loops count as bindings and are visible via loopback + +```py +def random() -> bool: + return False + +foo = None +Bar = None +while random(): + reveal_type(foo) # revealed: None | (def foo() -> None) + reveal_type(Bar) # revealed: None | + + def foo() -> None: ... + + class Bar: ... +``` + +### Walrus operator assignments are visible via loopback + +```py +def random() -> bool: + return False + +while random(): + # error: [possibly-unresolved-reference] + reveal_type(y) # revealed: Literal[1] + x = (y := 1) +``` + +### Loopback bindings are visible to the walrus operator in the loop condition + +```py +i = 0 +while (i := i + 1) < 1_000_000: + reveal_type(i) # revealed: int +``` + +### "Member" (as opposed to "symbol") places are also given loopback bindings + +```py +def random() -> bool: + return False + +my_dict = {} +my_dict["x"] = 0 +reveal_type(my_dict["x"]) # revealed: Literal[0] +while random(): + my_dict["x"] += 1 +reveal_type(my_dict["x"]) # revealed: int +``` + +### `del` prevents bindings from reaching the loopback + +This `x` cannot reach the use at the top of the loop: + +```py +def random() -> bool: + return False + +while random(): + x # error: [unresolved-reference] + x = 42 + del x +``` + +On the other hand, if `x` is defined before the loop, the `del` makes it a +`[possibly-unresolved-reference]`: + +```py +x = 0 +while random(): + x # error: [possibly-unresolved-reference] + x = 42 + del x +``` + +### `del` in a loop makes a variable possibly-unbound after the loop + +```py +def random() -> bool: + return False + +x = 0 +while random(): + # error: [possibly-unresolved-reference] + del x +# error: [possibly-unresolved-reference] +x +``` + +### Bindings in a loop are possibly-unbound after the loop + +```py +def random() -> bool: + return False + +while random(): + x = 42 +# error: [possibly-unresolved-reference] +x +``` + +### Swap bindings converge normally under fixpoint iteration + +```py +def random() -> bool: + return False + +x = 1 +y = 2 +while random(): + x, y = y, x + # Note that we get correct types in the "avoid oscillations" test case below, but not here. I + # believe the difference is that in this case the Salsa "cycle head" is the tuple on the RHS of + # the assignment, which triggers our recursive type handling, whereas below it's `x`. + # TODO: should be Literal[2, 1] + reveal_type(x) # revealed: Divergent + # TODO: should be Literal[1, 2] + reveal_type(y) # revealed: Divergent +``` + +### Tuple assignments are inferred correctly + +```py +def random() -> bool: + return False + +x = 0 +while random(): + x, y = x + 1, None + # TODO: should be int + reveal_type(x) # revealed: Divergent +``` + +### Avoid oscillations + +We need to avoid oscillating cycles in cases like the following, where the type of one of these loop +variables also influences the static reachability of its bindings. This case was minimized from a +real crash that came up during development checking these lines of `sympy`: + + +```py +def random() -> bool: + return False + +x = 1 +y = 2 +while random(): + if x: + x, y = y, x + reveal_type(x) # revealed: Literal[2, 1] + reveal_type(y) # revealed: Literal[1, 2] +``` + +### Loop bodies that are guaranteed to execute at least once + +TODO: We should be able to see when a loop body is guaranteed to execute at least once. However, +Pyright and other checkers don't currently handle this case either. + +```py +x = "foo" +while x != "bar": + definitely_bound = 42 + x = "bar" +# TODO: We should see that `definitely_bound` is definitely bound. +# error: [possibly-unresolved-reference] +reveal_type(definitely_bound) # revealed: Literal[42] +``` + +### Bindings in statically unreachable branches are excluded from loopback + +```py +VAL = 1 + +x = 1 +while True: + reveal_type(x) # revealed: Literal[1] + if VAL - 1: + x = 2 +``` + +### `Divergent` in narrowing conditions doesn't run afoul of "monotonic widening" in cycle recovery + +The following is a deceptively-simple-looking case of narrowing that was difficult to get right in +the initial implementation of cyclic control flow. We start with a non-empty linked list, and we +advance it in a loop until there's exactly one node left: + +```py +class Node: + def __init__(self, next: "Node | None" = None): + self.next: "Node | None" = next + +node = Node(Node(Node())) +while node.next is not None: + node = node.next +reveal_type(node) # revealed: Node +reveal_type(node.next) # revealed: None +``` + +There's nothing wrong with this code, and it was minimized from [real cases] in the ecosystem. But +it's prone to false-positive `[possibly-missing-attribute]` warnings on the `node.next` accesses if +we lose track of the fact that the `node` variable is never `None`. Note that the loop condition +narrows `node.next`, not `node` itself, so that constraint needs to flow through the assignment in +the loop body, and through the loop header definition that sees that assignment, to the prior uses +of `node` in the loop condition and in the RHS of the assignment. We expect that to become a Salsa +cycle that we resolve through fixpoint iteration. That runs into two of our cycle recovery +behaviors: + +1. When cycles show up in a standalone expression definition (in this case, the `while` loop + condition), the `cycle_initial` value (`expression_cycle_initial`) is an empty map with a + "fallback type" that reports `Divergent` for _every_ sub-expression. That even includes literal + expressions like `42` and (in this case) `None`. +1. To avoid oscillations in cycle recovery (`Type::cycle_normalized`), we union together the type + inferred in the previous iteration with the type inferred in the current one, as long as + neither of them contains `Divergent`. In other words, we do "monotonic widening". + +The interaction we have to worry about is getting stuck with a type that's too wide. When we try to +do narrowing in the first cycle iteration, `is not None` behaves like `is not Divergent`. If the +consequence is that we don't do any narrowing at all, then for that iteration we'll end up inferring +`Node | None` for `node`. (For completeness, we actually infer `Node | None | Divergent` because of +a nested cycle, but we strip out _that_ `Divergent` in another part of cycle recovery. The +[full chain of events here][divergent_debugging] is quite long.) In the second cycle iteration we'll +get the narrowing right and infer that `node` is of type `Node`, but then our monotonic widening +step will union `Node` with `Node | None` from the previous iteration, reproduce the same wrong +answer, and declare that to be the fixpoint. Finally we get false-positive warnings from the fact +that `Node` doesn't have a `.next` field. + +So, because we do monotonic widening in cycle recovery, we need to make sure that temporarily +`Divergent` expressions in narrowing constraints don't lead to too-wide-but-not-visibly-`Divergent` +types. Instead, `Divergent` should "poison" any value we try to narrow against it, so that our cycle +recovery logic doesn't carry that result forward. + +### `global` and `nonlocal` keywords in a loop + +We need to make sure that the loop header definition doesn't count as a "use" prior to the +`global`/`nonlocal` declaration, or else we'll emit a false-positive semantic syntax error: + +```py +x = 0 + +def _(): + y = 0 + def _(): + while True: + global x + nonlocal y + x = 42 + y = 99 +``` + +On the other hand, we don't want to shadow true positives: + +```py +x = 0 + +def _(): + y = 0 + def _(): + x = 1 + y = 1 + while True: + global x # error: [invalid-syntax] "name `x` is used prior to global declaration" + nonlocal y # error: [invalid-syntax] "name `y` is used prior to nonlocal declaration" +``` + +### Use with loop header and also `UNBOUND` definitely visible + +In `place_from_bindings_impl` we usually assert that if at least one (non-`UNBOUND`) binding is +visible, then `UNBOUND` should not be definitely-visible. That makes intuitive sense: either a +binding should shadow `UNBOUND` entirely, or if it was made in a branch then it should attach the +negated branch condition to `UNBOUND`. However, loop header bindings are an exception to this rule, +because they don't shadow prior bindings. In this example `UNBOUND` is definitely-visible, and we +need to avoid panicking: + +```py +while True: + x # error: [possibly-unresolved-reference] + x = 1 +``` + +### Loop header definitions don't shadow member bindings + +```py +class C: + x = None + +c = C() +c.x = 0 + +while True: + reveal_type(c.x) # revealed: Literal[0] + c = C() + break + +d = [0] +d[0] = 1 + +while True: + reveal_type(d[0]) # revealed: Literal[1] + d = [] + break +``` + +[divergent_debugging]: https://github.com/astral-sh/ruff/pull/22794#issuecomment-3852095578 +[real cases]: https://github.com/Finistere/antidote/blob/7d64ff76b7e283e5d9593ca09ea7a52b9b054957/src/antidote/_internal/localns.py#L34-L35 diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/while.md b/crates/ty_python_semantic/resources/mdtest/narrow/while.md index deae318666bff..390136cc7d1ef 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/while.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/while.md @@ -52,9 +52,7 @@ while x != 1: reveal_type(x) # revealed: Literal[2, 3] while x != 2: - # TODO: this should be Literal[1, 3]; Literal[3] is only correct - # in the first loop iteration - reveal_type(x) # revealed: Literal[3] + reveal_type(x) # revealed: Literal[3, 1] x = next_item() x = next_item() diff --git a/crates/ty_python_semantic/resources/mdtest/terminal_statements.md b/crates/ty_python_semantic/resources/mdtest/terminal_statements.md index 05bf71894ec61..63abf0e847182 100644 --- a/crates/ty_python_semantic/resources/mdtest/terminal_statements.md +++ b/crates/ty_python_semantic/resources/mdtest/terminal_statements.md @@ -144,10 +144,6 @@ are likely visible after the loop body, since loops do not introduce new scopes. infinite loops are one exception — if control never leaves the loop body, bindings inside of the loop are not visible outside of it.) -TODO: We are not currently modeling the cyclic control flow for loops, pending fixpoint support in -Salsa. The false positives in this section are because of that, and not our terminal statement -support. See [ruff#14160](https://github.com/astral-sh/ruff/issues/14160) for more details. - ```py def resolved_reference(cond: bool) -> str: while True: @@ -168,8 +164,7 @@ def continue_in_then_branch(cond: bool, i: int): x = "loop" reveal_type(x) # revealed: Literal["loop"] reveal_type(x) # revealed: Literal["loop"] - # TODO: Should be Literal["before", "loop", "continue"] - reveal_type(x) # revealed: Literal["before", "loop"] + reveal_type(x) # revealed: Literal["before", "continue", "loop"] def continue_in_else_branch(cond: bool, i: int): x = "before" @@ -182,8 +177,7 @@ def continue_in_else_branch(cond: bool, i: int): reveal_type(x) # revealed: Literal["continue"] continue reveal_type(x) # revealed: Literal["loop"] - # TODO: Should be Literal["before", "loop", "continue"] - reveal_type(x) # revealed: Literal["before", "loop"] + reveal_type(x) # revealed: Literal["before", "loop", "continue"] def continue_in_both_branches(cond: bool, i: int): x = "before" @@ -196,8 +190,7 @@ def continue_in_both_branches(cond: bool, i: int): x = "continue2" reveal_type(x) # revealed: Literal["continue2"] continue - # TODO: Should be Literal["before", "continue1", "continue2"] - reveal_type(x) # revealed: Literal["before"] + reveal_type(x) # revealed: Literal["before", "continue1", "continue2"] def continue_in_nested_then_branch(cond1: bool, cond2: bool, i: int): x = "before" @@ -215,8 +208,7 @@ def continue_in_nested_then_branch(cond1: bool, cond2: bool, i: int): reveal_type(x) # revealed: Literal["loop2"] reveal_type(x) # revealed: Literal["loop2"] reveal_type(x) # revealed: Literal["loop1", "loop2"] - # TODO: Should be Literal["before", "loop1", "loop2", "continue"] - reveal_type(x) # revealed: Literal["before", "loop1", "loop2"] + reveal_type(x) # revealed: Literal["before", "loop1", "continue", "loop2"] def continue_in_nested_else_branch(cond1: bool, cond2: bool, i: int): x = "before" @@ -234,8 +226,7 @@ def continue_in_nested_else_branch(cond1: bool, cond2: bool, i: int): continue reveal_type(x) # revealed: Literal["loop2"] reveal_type(x) # revealed: Literal["loop1", "loop2"] - # TODO: Should be Literal["before", "loop1", "loop2", "continue"] - reveal_type(x) # revealed: Literal["before", "loop1", "loop2"] + reveal_type(x) # revealed: Literal["before", "loop1", "loop2", "continue"] def continue_in_both_nested_branches(cond1: bool, cond2: bool, i: int): x = "before" @@ -253,8 +244,7 @@ def continue_in_both_nested_branches(cond1: bool, cond2: bool, i: int): reveal_type(x) # revealed: Literal["continue2"] continue reveal_type(x) # revealed: Literal["loop"] - # TODO: Should be Literal["before", "loop", "continue1", "continue2"] - reveal_type(x) # revealed: Literal["before", "loop"] + reveal_type(x) # revealed: Literal["before", "loop", "continue1", "continue2"] ``` ## `break` diff --git a/crates/ty_python_semantic/src/place.rs b/crates/ty_python_semantic/src/place.rs index 882701d1587fb..bf98399e58d70 100644 --- a/crates/ty_python_semantic/src/place.rs +++ b/crates/ty_python_semantic/src/place.rs @@ -5,11 +5,12 @@ use ty_module_resolver::{ }; use crate::dunder_all::dunder_all_names; -use crate::semantic_index::definition::{Definition, DefinitionState}; +use crate::semantic_index::definition::{Definition, DefinitionKind, DefinitionState}; use crate::semantic_index::place::{PlaceExprRef, ScopedPlaceId}; use crate::semantic_index::scope::ScopeId; use crate::semantic_index::{ - BindingWithConstraints, BindingWithConstraintsIterator, DeclarationsIterator, place_table, + BindingWithConstraints, BindingWithConstraintsIterator, DeclarationsIterator, get_loop_header, + place_table, }; use crate::semantic_index::{DeclarationWithConstraint, global_scope, use_def_map}; use crate::types::{ @@ -1162,6 +1163,7 @@ fn place_from_bindings_impl<'db>( bindings_with_constraints: BindingWithConstraintsIterator<'_, 'db>, requires_explicit_reexport: RequiresExplicitReExport, ) -> PlaceWithDefinition<'db> { + let all_definitions = bindings_with_constraints.all_definitions; let predicates = bindings_with_constraints.predicates; let reachability_constraints = bindings_with_constraints.reachability_constraints; let boundness_analysis = bindings_with_constraints.boundness_analysis; @@ -1191,6 +1193,7 @@ fn place_from_bindings_impl<'db>( }; let mut first_definition = None; + let mut only_loop_header_bindings = true; let mut types = bindings_with_constraints.filter_map( |BindingWithConstraints { @@ -1272,6 +1275,57 @@ fn place_from_bindings_impl<'db>( return None; } + // We need to "look through" loop header definitions to do boundness analysis. The + // actual type is computed by `infer_loop_header_definition` via `binding_type` below, + // like all other bindings, so that it can participate in fixpoint iteration. + if let DefinitionKind::LoopHeader(loop_header_kind) = binding.kind(db) { + let loop_header = get_loop_header(db, loop_header_kind.loop_token()); + let place = loop_header_kind.place(); + let mut has_defined_bindings = false; + for loop_back in loop_header.bindings_for_place(place) { + // Skip unreachable bindings. + if reachability_constraints + .evaluate(db, predicates, loop_back.reachability_constraint) + .is_always_false() + { + continue; + } + + // Resolve the definition state from the binding ID. + let def_state = all_definitions[loop_back.binding]; + + match def_state { + DefinitionState::Defined(_) => { + has_defined_bindings = true; + } + // `del` in the loop body is always visible to code after the loop via the + // normal control flow merge. Updating `deleted_reachability` here is + // necessary for prior uses in the loop to see it. + DefinitionState::Deleted => { + deleted_reachability = + deleted_reachability.or(reachability_constraints.evaluate( + db, + predicates, + loop_back.reachability_constraint, + )); + } + // If UNBOUND is visible at loop-back, then it was visible before the loop. + // Loop header definitions don't shadow preexisting bindings, so we don't + // need to do anything with this. + DefinitionState::Undefined => {} + } + } + // If all the bindings in the loop are in statically false branches, it might be + // that none of them loop-back. In that case short-circuit, so that we don't + // produce an `Unknown` fallback type, and so that `Place::Undefined` is still a + // possibility below. + if !has_defined_bindings { + return None; + } + } else { + only_loop_header_bindings = false; + } + first_definition.get_or_insert(binding); let binding_ty = binding_type(db, binding); Some(narrowing_constraint.narrow(db, binding_ty, binding.place(db))) @@ -1296,6 +1350,12 @@ fn place_from_bindings_impl<'db>( let boundness = match boundness_analysis { BoundnessAnalysis::AssumeBound => Definedness::AlwaysDefined, BoundnessAnalysis::BasedOnUnboundVisibility => match unbound_visibility() { + Some(Truthiness::AlwaysTrue) if only_loop_header_bindings => { + // Loop header definitions don't shadow prior bindings, so UNBOUND can still be + // definitely-visible alongside a loop header binding. See "Use with loop + // header and also `UNBOUND` definitely visible" in `while_loop.md`. + Definedness::PossiblyUndefined + } Some(Truthiness::AlwaysTrue) => { unreachable!( "If we have at least one binding, the implicit `unbound` binding should not be definitely visible" diff --git a/crates/ty_python_semantic/src/semantic_index.rs b/crates/ty_python_semantic/src/semantic_index.rs index fefa9be917410..e49d3188104ec 100644 --- a/crates/ty_python_semantic/src/semantic_index.rs +++ b/crates/ty_python_semantic/src/semantic_index.rs @@ -9,8 +9,11 @@ use ruff_python_parser::semantic_errors::SemanticSyntaxError; use rustc_hash::{FxHashMap, FxHashSet}; use salsa::Update; use salsa::plumbing::AsId; +use smallvec::SmallVec; use ty_module_resolver::ModuleName; +use crate::semantic_index::place::ScopedPlaceId; + use crate::Db; use crate::node_key::NodeKey; use crate::semantic_index::ast_ids::AstIds; @@ -44,7 +47,7 @@ mod use_def; pub(crate) use self::use_def::{ ApplicableConstraints, BindingWithConstraints, BindingWithConstraintsIterator, - DeclarationWithConstraint, DeclarationsIterator, + DeclarationWithConstraint, DeclarationsIterator, LiveBinding, }; /// Returns the semantic index for `file`. @@ -95,6 +98,95 @@ pub(crate) fn use_def_map<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc>, +} + +impl LoopHeader { + pub(crate) fn new() -> Self { + Self { + bindings: FxHashMap::default(), + } + } + + pub(crate) fn add_binding(&mut self, place: ScopedPlaceId, binding: LiveBinding) { + self.bindings.entry(place).or_default().push(binding); + } + + pub(crate) fn bindings_for_place( + &self, + place: ScopedPlaceId, + ) -> impl Iterator + '_ { + self.bindings + .get(&place) + .map(|v: &SmallVec<[LiveBinding; 1]>| v.iter().copied()) + .into_iter() + .flatten() + } +} + +/// A Salsa token for retrieving a `LoopHeader`. See `get_loop_header` below. +#[salsa::tracked(debug, heap_size=ruff_memory_usage::heap_size)] +pub struct LoopToken<'db> {} + +impl get_size2::GetSize for LoopToken<'_> {} + +/// Look up a `LoopHeader` given a `LoopToken`. +/// +/// Loop header definitions are the very first things we encounter (synthesize) when we walk a +/// loop, and they need to refer to the corresponding the `LoopHeader` struct that records their +/// bindings, but that struct isn't available until we've finished walking the loop. To make this +/// work in the largely immutable world of Salsa, we add a layer of indirection using a Salsa +/// feature called "specify": +/// +/// +/// When we first encounter a loop, we generate a `LoopToken` that uniquely identifies the loop but +/// doesn't contain any data. We do a lightweight pre-walk to collect bound places (see +/// `LoopBindingsVisitor`), and for each bound place we create a loop header definition that stores +/// the `LoopToken`. Then after we've finished visiting the loop, we call +/// `get_loop_header::specify` to associate the token with the completed `LoopHeader`. All of this +/// happens while we're building the semantic index, and nothing needs to call `get_loop_header` +/// until we get to type inference later, so the order of operations always works out. +#[salsa::tracked(specify, heap_size=ruff_memory_usage::heap_size)] +pub(crate) fn get_loop_header<'db>(_db: &'db dyn Db, _loop_token: LoopToken<'db>) -> LoopHeader { + panic!("should always be set by specify()"); +} + /// Returns all attribute assignments (and their method scope IDs) with a symbol name matching /// the one given for a specific class body scope. /// diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index 2c70ca634a261..76cb3c3df1ef3 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -27,8 +27,8 @@ use crate::semantic_index::definition::{ ComprehensionDefinitionNodeRef, Definition, DefinitionCategory, DefinitionNodeKey, DefinitionNodeRef, Definitions, DictKeyAssignmentNodeRef, ExceptHandlerDefinitionNodeRef, ForStmtDefinitionNodeRef, ImportDefinitionNodeRef, ImportFromDefinitionNodeRef, - ImportFromSubmoduleDefinitionNodeRef, MatchPatternDefinitionNodeRef, - StarImportDefinitionNodeRef, WithItemDefinitionNodeRef, + ImportFromSubmoduleDefinitionNodeRef, LoopHeaderDefinitionNodeRef, LoopStmtRef, + MatchPatternDefinitionNodeRef, StarImportDefinitionNodeRef, WithItemDefinitionNodeRef, }; use crate::semantic_index::expression::{Expression, ExpressionKind}; use crate::semantic_index::member::MemberExprBuilder; @@ -47,26 +47,37 @@ use crate::semantic_index::scope::{ use crate::semantic_index::scope::{Scope, ScopeId, ScopeKind, ScopeLaziness}; use crate::semantic_index::symbol::{ScopedSymbolId, Symbol}; use crate::semantic_index::use_def::{ - EnclosingSnapshotKey, FlowSnapshot, ScopedEnclosingSnapshotId, UseDefMapBuilder, + EnclosingSnapshotKey, FlowSnapshot, PreviousDefinitions, ScopedEnclosingSnapshotId, + UseDefMapBuilder, +}; +use crate::semantic_index::{ + ExpressionsScopeMap, LoopHeader, LoopToken, SemanticIndex, VisibleAncestorsIter, + get_loop_header, }; -use crate::semantic_index::{ExpressionsScopeMap, SemanticIndex, VisibleAncestorsIter}; use crate::semantic_model::HasTrackedScope; use crate::types::PossiblyNarrowedPlaces; use crate::unpack::{EvaluationMode, Unpack, UnpackKind, UnpackPosition, UnpackValue}; use crate::{Db, Program}; mod except_handlers; +mod loop_bindings_visitor; #[derive(Clone, Debug, Default)] struct Loop { /// Flow states at each `break` in the current loop. break_states: Vec, + /// Flow states at each `continue` in the current loop. + continue_states: Vec, } impl Loop { fn push_break(&mut self, state: FlowSnapshot) { self.break_states.push(state); } + + fn push_continue(&mut self, state: FlowSnapshot) { + self.continue_states.push(state); + } } struct ScopeInfo { @@ -694,10 +705,9 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { /// and the second element is the number of definitions that are now associated with /// `definition_node`. /// - /// This method should only be used when adding a definition associated with a `*` import. - /// All other nodes can only ever be associated with exactly 1 or 0 [`Definition`]s. - /// For any node other than an [`ast::Alias`] representing a `*` import, - /// prefer to use `self.add_definition()`, which ensures that this invariant is maintained. + /// Most AST nodes can only be associated with at most one [`Definition`]. Generally prefer + /// `add_definition` above, which enforces that. This method should currently only be used with + /// `*` imports and loop headers. fn push_additional_definition( &mut self, place: ScopedPlaceId, @@ -707,6 +717,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { // Note `definition_node` is guaranteed to be a child of `self.module` let kind = definition_node.into_owned(self.module); + let is_loop_header = kind.is_loop_header(); let category = kind.category(self.source_type.is_stub(), self.module); let is_reexported = kind.is_reexported(); @@ -726,7 +737,15 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { definitions.len() }; - if category.is_binding() { + // We need to avoid marking places as bound as soon as we encounter a loop header + // definition for them, because that would lead to false-positive semantic syntax errors in + // cases like this: + // ```py + // while True: + // global x # [invalid-syntax] if `x` is already used or bound + // x = 1 + // ``` + if category.is_binding() && !is_loop_header { self.mark_place_bound(place); } if category.is_declaration() { @@ -741,8 +760,16 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { } DefinitionCategory::Declaration => use_def.record_declaration(place, definition), DefinitionCategory::Binding => { - use_def.record_binding(place, definition); - self.delete_associated_bindings(place); + // Loop-header bindings don't shadow prior bindings. + let previous_definitions = if is_loop_header { + PreviousDefinitions::AreKept + } else { + PreviousDefinitions::AreShadowed + }; + use_def.record_binding(place, definition, previous_definitions); + if !is_loop_header { + self.delete_associated_bindings(place); + } } } @@ -812,6 +839,65 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { } } + /// Create loop header definitions for all places that are bound within a loop. Return the + /// `LoopToken` referenced by those definitions, and the set of bound place IDs. + fn synthesize_loop_header_definitions( + &mut self, + loop_stmt: LoopStmtRef<'ast>, + bound_places: Vec, + ) -> (LoopToken<'db>, FxHashSet) { + let loop_token = LoopToken::new(self.db); + let mut bound_place_ids: FxHashSet = FxHashSet::default(); + for place_expr in bound_places { + let place_id = self.add_place(place_expr); + if bound_place_ids.insert(place_id) { + let loop_header_ref = LoopHeaderDefinitionNodeRef { + loop_stmt, + place: place_id, + loop_token, + }; + // Note that `DefinitionKind::LoopHeader` doesn't shadow prior bindings. + self.push_additional_definition(place_id, loop_header_ref); + } + } + (loop_token, bound_place_ids) + } + + /// Build a `LoopHeader` that tracks all the variables bound in a loop, which will be visible + /// to uses in the same loop via "loop header definitions". We call this after merging control + /// flow from all the loop-back edges, most importantly at the end of the loop body, and also + /// at any `continue` statements. + fn populate_loop_header( + &mut self, + loop_header_places: &FxHashSet, + loop_token: LoopToken<'db>, + ) { + let mut loop_header = LoopHeader::new(); + let use_def = self.current_use_def_map_mut(); + // Collect bindings. + for place_id in loop_header_places { + for live_binding in use_def.loop_back_bindings(*place_id) { + loop_header.add_binding(*place_id, live_binding); + } + } + // Mark the reachability and narrowing constraints as used. + for place_id in loop_header_places { + for live_binding in loop_header.bindings_for_place(*place_id) { + use_def + .reachability_constraints + .mark_used(live_binding.reachability_constraint); + use_def + .reachability_constraints + .mark_used(live_binding.narrowing_constraint); + } + } + // The `LoopHeader` needs to be visible to uses within the loop body that we've already + // walked, but all our Salsa state is generally immutable. `specify` is how we work around + // that. See this section of the Salsa docs: + // + get_loop_header::specify(self.db, loop_token, loop_header); + } + fn record_expression_narrowing_constraint( &mut self, predicate_node: &ast::Expr, @@ -2077,15 +2163,36 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { self.in_type_checking_block = is_outer_block_in_type_checking; } - ast::Stmt::While(ast::StmtWhile { - test, - body, - orelse, - range: _, - node_index: _, - }) => { + ast::Stmt::While( + while_stmt @ ast::StmtWhile { + test, + body, + orelse, + range: _, + node_index: _, + }, + ) => { + // Pre-walk the loop to collect all the bound places, then create a loop header + // definition for each bound place. See `struct LoopHeader` for more on this. Loop + // header definitions stash a token to look up the `LoopHeader` later, so that we + // can populate the header lazily. + let bound_places = loop_bindings_visitor::collect_while_loop_bindings(while_stmt); + let mut maybe_loop_header_info = None; + // Avoid allocating a `LoopToken` if there are no bound places in this loop. + if !bound_places.is_empty() { + maybe_loop_header_info = Some(self.synthesize_loop_header_definitions( + LoopStmtRef::While(while_stmt), + bound_places, + )); + } + + // Visit the test expression after creating loop headers, so that loop-back values + // are visible. self.visit_expr(test); + // Take the pre_loop snapshot after visiting the test expression, so that walrus + // bindings in the test (which are always evaluated at least once) remain visible + // after the loop. let pre_loop = self.flow_snapshot(); let (predicate, predicate_id) = self.record_expression_narrowing_constraint(test); self.record_reachability_constraint(predicate); @@ -2094,11 +2201,23 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { self.visit_body(body); let this_loop = self.pop_loop(outer_loop); + // Loop-back bindings include everything that's visible if/when control reaches the + // end of the loop body, and they also include everything that's visible to a + // `continue` statement. Merge the `continue` states before collecting bindings. + for continue_state in this_loop.continue_states { + self.flow_merge(continue_state); + } + + // Collect all the loop-back bindings (including the `continue` states we just + // merged) and populate the `LoopHeader`. + if let Some((loop_token, bound_place_ids)) = maybe_loop_header_info { + self.populate_loop_header(&bound_place_ids, loop_token); + } + // We execute the `else` branch once the condition evaluates to false. This could // happen without ever executing the body, if the condition is false the first time // it's tested. Or it could happen if a _later_ evaluation of the condition yields // false. So we merge in the pre-loop state here into the post-body state: - self.flow_merge(pre_loop); // The `else` branch can only be reached if the loop condition *can* be false. To @@ -2170,12 +2289,39 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { let pre_loop = self.flow_snapshot(); + // Pre-walk the loop to collect all the bound places, then create a loop header + // definition for each bound place. See `struct LoopHeader` for more on this. Loop + // header definitions stash a token to look up the `LoopHeader` later, so that we + // can populate the header lazily. + let bound_places = loop_bindings_visitor::collect_for_loop_bindings(for_stmt); + let mut maybe_loop_header_info = None; + // Avoid allocating a `LoopToken` if there are no bound places in this loop. + if !bound_places.is_empty() { + maybe_loop_header_info = Some(self.synthesize_loop_header_definitions( + LoopStmtRef::For(for_stmt), + bound_places, + )); + } + self.add_unpackable_assignment(&Unpackable::For(for_stmt), target, iter_expr); let outer_loop = self.push_loop(); self.visit_body(body); let this_loop = self.pop_loop(outer_loop); + // Loop-back bindings include everything that's visible if/when control reaches the + // end of the loop body, and they also include everything that's visible to a + // `continue` statement. Merge the `continue` states before collecting bindings. + for continue_state in this_loop.continue_states { + self.flow_merge(continue_state); + } + + // Collect all the loop-back bindings (including the `continue` states we just + // merged) and populate the `LoopHeader`. + if let Some((loop_token, bound_place_ids)) = maybe_loop_header_info { + self.populate_loop_header(&bound_place_ids, loop_token); + } + // We may execute the `else` clause without ever executing the body, so merge in // the pre-loop state before visiting `else`. self.flow_merge(pre_loop); @@ -2413,12 +2559,21 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { self.visit_body(finalbody); } - ast::Stmt::Raise(_) | ast::Stmt::Return(_) | ast::Stmt::Continue(_) => { + ast::Stmt::Raise(_) | ast::Stmt::Return(_) => { walk_stmt(self, stmt); // Everything in the current block after a terminal statement is unreachable. self.mark_unreachable(); } + ast::Stmt::Continue(_) => { + let snapshot = self.flow_snapshot(); + if let Some(current_loop) = self.current_loop_mut() { + current_loop.push_continue(snapshot); + } + // Everything in the current block after a terminal statement is unreachable. + self.mark_unreachable(); + } + ast::Stmt::Break(_) => { let snapshot = self.flow_snapshot(); if let Some(current_loop) = self.current_loop_mut() { diff --git a/crates/ty_python_semantic/src/semantic_index/builder/loop_bindings_visitor.rs b/crates/ty_python_semantic/src/semantic_index/builder/loop_bindings_visitor.rs new file mode 100644 index 0000000000000..74918572a81b4 --- /dev/null +++ b/crates/ty_python_semantic/src/semantic_index/builder/loop_bindings_visitor.rs @@ -0,0 +1,477 @@ +use ruff_python_ast as ast; +use ruff_python_ast::visitor::{Visitor, walk_expr, walk_pattern, walk_stmt}; + +use crate::semantic_index::place::PlaceExpr; +use crate::semantic_index::symbol::Symbol; + +/// Do a pre-walk of a `while` loop to collect all the places that are bound, prior to visiting the +/// loop with `SemanticIndexBuilder`. This walk includes bindings in nested loops, but not in +/// nested scopes. (I.e. we don't descend into function bodies or class definitions.) We need this +/// pre-walk so that we can synthesize "loop header definitions" that are visible to the loop body +/// (and condition). See `LoopHeader`. +/// TODO: Handle `nonlocal` bindings from nested scopes somehow. +pub(crate) fn collect_while_loop_bindings(while_stmt: &ast::StmtWhile) -> Vec { + let mut collector = LoopBindingsVisitor::default(); + collector.visit_expr(&while_stmt.test); + collector.visit_body(&while_stmt.body); + collector.bound_places +} + +/// Like `collect_while_loop_bindings` above, but for `for` loops. +pub(crate) fn collect_for_loop_bindings(for_stmt: &ast::StmtFor) -> Vec { + let mut collector = LoopBindingsVisitor::default(); + collector.add_place_from_target(&for_stmt.target); + collector.visit_body(&for_stmt.body); + collector.bound_places +} + +/// The visitor that powers `collect_while_loop_bindings` and `collect_for_loop_bindings`. +/// +/// This visitor doesn't walk nested function/class definitions since those are different scopes. +#[derive(Debug, Default)] +pub(crate) struct LoopBindingsVisitor { + bound_places: Vec, +} + +impl LoopBindingsVisitor { + pub(crate) fn add_place_from_target(&mut self, target: &ast::Expr) { + match target { + ast::Expr::Name(name) => { + self.bound_places.push(PlaceExpr::from_expr_name(name)); + } + ast::Expr::Attribute(_) | ast::Expr::Subscript(_) => { + if let Some(place) = PlaceExpr::try_from_expr(target) { + self.bound_places.push(place); + } + } + ast::Expr::Tuple(tuple) => { + for elt in &tuple.elts { + self.add_place_from_target(elt); + } + } + ast::Expr::List(list) => { + for elt in &list.elts { + self.add_place_from_target(elt); + } + } + ast::Expr::Starred(starred) => { + self.add_place_from_target(&starred.value); + } + _ => {} + } + } +} + +impl<'ast> Visitor<'ast> for LoopBindingsVisitor { + fn visit_stmt(&mut self, stmt: &'ast ast::Stmt) { + match stmt { + ast::Stmt::Assign(node) => { + for target in &node.targets { + self.add_place_from_target(target); + } + // Visit the value expression to find named expressions (walrus operator). + self.visit_expr(&node.value); + } + ast::Stmt::AugAssign(node) => { + self.add_place_from_target(&node.target); + self.visit_expr(&node.value); + } + ast::Stmt::AnnAssign(node) => { + if let Some(value) = &node.value { + self.add_place_from_target(&node.target); + self.visit_expr(value); + } + } + ast::Stmt::For(node) => { + self.add_place_from_target(&node.target); + self.visit_expr(&node.iter); + self.visit_body(&node.body); + self.visit_body(&node.orelse); + } + ast::Stmt::While(node) => { + self.visit_expr(&node.test); + self.visit_body(&node.body); + self.visit_body(&node.orelse); + } + ast::Stmt::With(node) => { + for item in &node.items { + self.visit_expr(&item.context_expr); + if let Some(vars) = &item.optional_vars { + self.add_place_from_target(vars); + } + } + self.visit_body(&node.body); + } + ast::Stmt::Try(node) => { + self.visit_body(&node.body); + for handler in &node.handlers { + let ast::ExceptHandler::ExceptHandler(h) = handler; + if let Some(name) = &h.name { + self.bound_places + .push(PlaceExpr::Symbol(Symbol::new(name.id.clone()))); + } + self.visit_body(&h.body); + } + self.visit_body(&node.orelse); + self.visit_body(&node.finalbody); + } + ast::Stmt::Import(node) => { + for alias in &node.names { + let name = alias.asname.as_ref().unwrap_or(&alias.name); + self.bound_places + .push(PlaceExpr::Symbol(Symbol::new(name.id.clone()))); + } + } + ast::Stmt::ImportFrom(node) => { + for alias in &node.names { + if &*alias.name != "*" { + let name = alias.asname.as_ref().unwrap_or(&alias.name); + self.bound_places + .push(PlaceExpr::Symbol(Symbol::new(name.id.clone()))); + } + } + } + ast::Stmt::FunctionDef(node) => { + self.bound_places + .push(PlaceExpr::Symbol(Symbol::new(node.name.id.clone()))); + // Don't descend into function bodies - they're different scopes. + } + ast::Stmt::ClassDef(node) => { + self.bound_places + .push(PlaceExpr::Symbol(Symbol::new(node.name.id.clone()))); + // Don't descend into class bodies - they're different scopes. + } + ast::Stmt::Match(node) => { + self.visit_expr(&node.subject); + for case in &node.cases { + if let Some(guard) = &case.guard { + self.visit_expr(guard); + } + self.visit_pattern(&case.pattern); + self.visit_body(&case.body); + } + } + ast::Stmt::Delete(node) => { + for target in &node.targets { + self.add_place_from_target(target); + } + } + _ => walk_stmt(self, stmt), + } + } + + fn visit_expr(&mut self, expr: &'ast ast::Expr) { + // the walrus operator + if let ast::Expr::Named(node) = expr { + self.add_place_from_target(&node.target); + } + walk_expr(self, expr); + } + + fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) { + match pattern { + ast::Pattern::MatchAs(p) => { + if let Some(name) = &p.name { + self.bound_places + .push(PlaceExpr::Symbol(Symbol::new(name.id.clone()))); + } + } + ast::Pattern::MatchStar(p) => { + if let Some(name) = &p.name { + self.bound_places + .push(PlaceExpr::Symbol(Symbol::new(name.id.clone()))); + } + } + ast::Pattern::MatchMapping(p) => { + if let Some(rest) = &p.rest { + self.bound_places + .push(PlaceExpr::Symbol(Symbol::new(rest.id.clone()))); + } + } + _ => {} + } + walk_pattern(self, pattern); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ruff_python_parser::parse_module; + use ruff_python_trivia::textwrap::dedent; + + // Test collecting `while` loop bindings. + + fn collect_while_loop_place_names(code: &str) -> Vec { + let parsed = parse_module(code).expect("valid Python code"); + let stmt = &parsed.suite()[0]; + let ast::Stmt::While(while_stmt) = stmt else { + panic!("Expected a while statement"); + }; + collect_while_loop_bindings(while_stmt) + .into_iter() + .map(|place| match place { + PlaceExpr::Symbol(sym) => sym.name().to_string(), + PlaceExpr::Member(member) => member.to_string(), + }) + .collect() + } + + #[test] + fn test_collect_while_loop() { + let bindings = collect_while_loop_place_names(&dedent( + " + while True: + x = 1 + y = 2 + x = 3 + else: + z = 4 + ", + )); + // `z` is not collected, because it's not visible to the loopback edge. + assert_eq!(bindings, vec!["x", "y", "x"]); + } + + #[test] + fn test_collect_while_loop_nested() { + let bindings = collect_while_loop_place_names(&dedent( + " + while True: + a = 1 + if some_condition: + b = 2 + while some_condition: + c = 3 + for d in e: + f = 4 + [g := 42 for x in [h := 99 for _ in 'hello world']] + ", + )); + // Note that "x", the comprehension variable, is not included, but "g", a walrus assignment + // within the comprehension, is included. + assert_eq!(bindings, vec!["a", "b", "c", "d", "f", "h", "g"]); + } + + #[test] + fn test_collect_while_loop_walrus_in_condition() { + let bindings = collect_while_loop_place_names(&dedent( + " + while (x := get_next()): + y = x + 1 + ", + )); + assert_eq!(bindings, vec!["x", "y"]); + } + + // Test collecting `for` loop bindings. + + fn collect_for_loop_place_names(code: &str) -> Vec { + let parsed = parse_module(code).expect("valid Python code"); + let stmt = &parsed.suite()[0]; + let ast::Stmt::For(for_stmt) = stmt else { + panic!("Expected a for statement"); + }; + collect_for_loop_bindings(for_stmt) + .into_iter() + .map(|place| match place { + PlaceExpr::Symbol(sym) => sym.name().to_string(), + PlaceExpr::Member(member) => member.to_string(), + }) + .collect() + } + + #[test] + fn test_collect_for_loop() { + let bindings = collect_for_loop_place_names(&dedent( + " + for i in range(10): + x = 1 + y = 2 + x = 3 + else: + z = 4 + ", + )); + // `z` is not collected, because it's not visible to the loopback edge. + assert_eq!(bindings, vec!["i", "x", "y", "x"]); + } + + #[test] + fn test_collect_for_loop_nested() { + let bindings = collect_for_loop_place_names(&dedent( + " + for i in range(10): + a = 1 + if some_condition: + b = 2 + while some_condition: + c = 3 + for d in e: + f = 4 + [g := 42 for x in [h := 99 for _ in 'hello world']] + ", + )); + // Note that "x", the comprehension variable, is not included, but "g", a walrus assignment + // within the comprehension, is included. + assert_eq!(bindings, vec!["i", "a", "b", "c", "d", "f", "h", "g"]); + } + + /// `LoopBindingsVisitor` has to handle a lot of different types of bindings. Exercise all of + /// them at least once. + #[test] + fn test_all_different_binding_kinds() { + enum LoopKind { + While, + For, + } + let loop_cases = [ + ("while True:", LoopKind::While), + ("for for_loop_var in range(1_000_000):", LoopKind::For), + ("async for for_loop_var in range(1_000_000):", LoopKind::For), + ]; + for (loop_header, loop_kind) in loop_cases { + let code_snippet = dedent(&format!( + r#" + {loop_header} + simple_assign = 1 + tuple_unpack_a, tuple_unpack_b = (1, 2) + [list_unpack_a, list_unpack_b] = [1, 2] + first, *starred_rest, last = [1, 2, 3, 4] + obj.attr_target = 1 + obj["subscript_target"] = 1 + aug_assign += 1 + ann_assign: int = 1 + for for_target in items: + for_body_binding = 1 + while condition: + while_body_binding = 1 + with ctx() as with_var: + with_body_binding = 1 + with ctx() as (with_tuple_a, with_tuple_b): + pass + async with ctx() as async_with_var: + async_with_body_binding = 1 + try: + try_body_binding = 1 + except Exception as exc_var: + except_body_binding = 1 + finally: + finally_binding = 1 + import mod_a + import mod_b as mod_b_alias + from pkg import name_c + from pkg import name_d as name_d_alias + def func_def(): ... + class ClassDef: ... + (walrus_var := 42) + assign_with_walrus = (walrus_in_assign := 1) + aug_assign_walrus += (walrus_in_aug_assign := 1) + ann_assign_walrus: int = (walrus_in_ann_assign := 1) + for walrus_for_target in (walrus_in_for_iter := items): + walrus_for_body = 1 + with (walrus_in_with_ctx := ctx()) as walrus_with_var: + walrus_with_body = 1 + match (walrus_in_match_subject := value): + case match_as_var: + match_as_body = 1 + case _ if (walrus_in_match_guard := guard()): + match_guard_body = 1 + case int() as match_as_with_pattern: ... + case [seq_first, *match_star_rest, seq_last]: ... + case {{"key": mapping_val, **match_mapping_rest}}: ... + case Point(class_pos_x, y=class_kw_y): ... + case match_or_a | match_or_b: ... + case [seq_a, seq_b]: ... + case 42 | None | True: ... + del deleted_variable + [list_comp_iter for list_comp_iter in range(10)] + {{set_comp_iter for set_comp_iter in range(10)}} + (gen_comp_iter for gen_comp_iter in range(10)) + {{dk: dv for dk, dv in items}} + [walrus_in_list_comp := 42 for _ in range(10)] + [a for a in (walrus_in_comp_iter := range(10))] + "#, + )) + .into_owned(); + + let mut expected_bindings = vec![ + "simple_assign", + "tuple_unpack_a", + "tuple_unpack_b", + "list_unpack_a", + "list_unpack_b", + "first", + "starred_rest", + "last", + "obj.attr_target", + "obj[\"subscript_target\"]", + "aug_assign", + "ann_assign", + "for_target", + "for_body_binding", + "while_body_binding", + "with_var", + "with_body_binding", + "with_tuple_a", + "with_tuple_b", + "async_with_var", + "async_with_body_binding", + "try_body_binding", + "exc_var", + "except_body_binding", + "finally_binding", + "mod_a", + "mod_b_alias", + "name_c", + "name_d_alias", + "func_def", + "ClassDef", + "walrus_var", + "assign_with_walrus", + "walrus_in_assign", + "aug_assign_walrus", + "walrus_in_aug_assign", + "ann_assign_walrus", + "walrus_in_ann_assign", + "walrus_for_target", + "walrus_in_for_iter", + "walrus_for_body", + "walrus_in_with_ctx", + "walrus_with_var", + "walrus_with_body", + "walrus_in_match_subject", + "match_as_var", + "match_as_body", + "walrus_in_match_guard", + "match_guard_body", + "match_as_with_pattern", + "seq_first", + "match_star_rest", + "seq_last", + "match_mapping_rest", + "mapping_val", + "class_pos_x", + "class_kw_y", + "match_or_a", + "match_or_b", + "seq_a", + "seq_b", + "deleted_variable", + // Only the LHS of walrus operators gets collected from comprehensions. + "walrus_in_list_comp", + "walrus_in_comp_iter", + ]; + if matches!(loop_kind, LoopKind::For) { + expected_bindings.insert(0, "for_loop_var"); + } + + let bindings = match loop_kind { + LoopKind::While => collect_while_loop_place_names(&code_snippet), + LoopKind::For => collect_for_loop_place_names(&code_snippet), + }; + + assert_eq!(bindings, expected_bindings); + } + } +} diff --git a/crates/ty_python_semantic/src/semantic_index/definition.rs b/crates/ty_python_semantic/src/semantic_index/definition.rs index d6df3b3d20867..073a3033e7797 100644 --- a/crates/ty_python_semantic/src/semantic_index/definition.rs +++ b/crates/ty_python_semantic/src/semantic_index/definition.rs @@ -10,6 +10,7 @@ use ruff_text_size::{Ranged, TextRange, TextSize}; use crate::Db; use crate::ast_node_ref::AstNodeRef; use crate::node_key::NodeKey; +use crate::semantic_index::LoopToken; use crate::semantic_index::place::ScopedPlaceId; use crate::semantic_index::scope::{FileScopeId, ScopeId}; use crate::semantic_index::symbol::ScopedSymbolId; @@ -286,6 +287,7 @@ pub(crate) enum DefinitionNodeRef<'ast, 'db> { TypeVar(&'ast ast::TypeParamTypeVar), ParamSpec(&'ast ast::TypeParamParamSpec), TypeVarTuple(&'ast ast::TypeParamTypeVarTuple), + LoopHeader(LoopHeaderDefinitionNodeRef<'ast, 'db>), } impl<'ast> From<&'ast ast::StmtFunctionDef> for DefinitionNodeRef<'ast, '_> { @@ -336,6 +338,12 @@ impl<'ast> From<&'ast ast::TypeParamTypeVarTuple> for DefinitionNodeRef<'ast, '_ } } +impl<'ast, 'db> From> for DefinitionNodeRef<'ast, 'db> { + fn from(value: LoopHeaderDefinitionNodeRef<'ast, 'db>) -> Self { + Self::LoopHeader(value) + } +} + impl<'ast> From> for DefinitionNodeRef<'ast, '_> { fn from(node_ref: ImportDefinitionNodeRef<'ast>) -> Self { Self::Import(node_ref) @@ -479,6 +487,19 @@ pub(crate) struct ExceptHandlerDefinitionNodeRef<'ast> { pub(crate) is_star: bool, } +#[derive(Copy, Clone, Debug)] +pub(crate) struct LoopHeaderDefinitionNodeRef<'ast, 'db> { + pub(crate) loop_stmt: LoopStmtRef<'ast>, + pub(crate) place: ScopedPlaceId, + pub(crate) loop_token: LoopToken<'db>, +} + +#[derive(Copy, Clone, Debug)] +pub(crate) enum LoopStmtRef<'ast> { + While(&'ast ast::StmtWhile), + For(&'ast ast::StmtFor), +} + #[derive(Copy, Clone, Debug)] pub(crate) struct ComprehensionDefinitionNodeRef<'ast, 'db> { pub(crate) unpack: Option<(UnpackPosition, Unpack<'db>)>, @@ -648,6 +669,18 @@ impl<'db> DefinitionNodeRef<'_, 'db> { DefinitionNodeRef::TypeVarTuple(node) => { DefinitionKind::TypeVarTuple(AstNodeRef::new(parsed, node)) } + DefinitionNodeRef::LoopHeader(LoopHeaderDefinitionNodeRef { + loop_stmt, + place, + loop_token, + }) => DefinitionKind::LoopHeader(LoopHeaderDefinitionKind { + loop_token, + loop_stmt: match loop_stmt { + LoopStmtRef::While(stmt) => LoopStmtKind::While(AstNodeRef::new(parsed, stmt)), + LoopStmtRef::For(stmt) => LoopStmtKind::For(AstNodeRef::new(parsed, stmt)), + }, + place, + }), } } @@ -715,6 +748,10 @@ impl<'db> DefinitionNodeRef<'_, 'db> { Self::TypeVar(node) => node.into(), Self::ParamSpec(node) => node.into(), Self::TypeVarTuple(node) => node.into(), + Self::LoopHeader(LoopHeaderDefinitionNodeRef { loop_stmt, .. }) => match loop_stmt { + LoopStmtRef::While(stmt) => stmt.into(), + LoopStmtRef::For(stmt) => stmt.into(), + }, } } } @@ -786,6 +823,7 @@ pub enum DefinitionKind<'db> { TypeVar(AstNodeRef), ParamSpec(AstNodeRef), TypeVarTuple(AstNodeRef), + LoopHeader(LoopHeaderDefinitionKind<'db>), } impl DefinitionKind<'_> { @@ -830,6 +868,10 @@ impl DefinitionKind<'_> { matches!(self, DefinitionKind::Function(_)) } + pub(crate) const fn is_loop_header(&self) -> bool { + matches!(self, DefinitionKind::LoopHeader(_)) + } + /// Returns the [`TextRange`] of the definition target. /// /// A definition target would mainly be the node representing the place being defined i.e., @@ -871,6 +913,7 @@ impl DefinitionKind<'_> { DefinitionKind::TypeVarTuple(type_var_tuple) => { type_var_tuple.node(module).name.range() } + DefinitionKind::LoopHeader(loop_header) => loop_header.range(module), } } @@ -919,6 +962,7 @@ impl DefinitionKind<'_> { DefinitionKind::TypeVar(type_var) => type_var.node(module).range(), DefinitionKind::ParamSpec(param_spec) => param_spec.node(module).range(), DefinitionKind::TypeVarTuple(type_var_tuple) => type_var_tuple.node(module).range(), + DefinitionKind::LoopHeader(loop_header) => loop_header.range(module), } } @@ -975,7 +1019,8 @@ impl DefinitionKind<'_> { | DefinitionKind::WithItem(_) | DefinitionKind::MatchPattern(_) | DefinitionKind::ImportFromSubmodule(_) - | DefinitionKind::ExceptHandler(_) => DefinitionCategory::Binding, + | DefinitionKind::ExceptHandler(_) + | DefinitionKind::LoopHeader(_) => DefinitionCategory::Binding, } } @@ -1308,6 +1353,39 @@ impl ExceptHandlerDefinitionKind { } } +/// Definition kind for a loop header entry. +#[derive(Clone, Debug, get_size2::GetSize)] +pub struct LoopHeaderDefinitionKind<'db> { + /// The `LoopHeader` struct isn't ready when this type of definition is created. Instead we + /// look it up later by passing this token to `get_loop_header`. + loop_token: LoopToken<'db>, + loop_stmt: LoopStmtKind, + place: ScopedPlaceId, +} + +#[derive(Clone, Debug, get_size2::GetSize)] +pub(crate) enum LoopStmtKind { + While(AstNodeRef), + For(AstNodeRef), +} + +impl<'db> LoopHeaderDefinitionKind<'db> { + pub(crate) fn loop_token(&self) -> LoopToken<'db> { + self.loop_token + } + + pub(crate) fn place(&self) -> ScopedPlaceId { + self.place + } + + pub(crate) fn range(&self, module: &ParsedModuleRef) -> TextRange { + match &self.loop_stmt { + LoopStmtKind::While(stmt) => stmt.node(module).range(), + LoopStmtKind::For(stmt) => stmt.node(module).range(), + } + } +} + #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, salsa::Update, get_size2::GetSize)] pub(crate) struct DefinitionNodeKey(NodeKey); @@ -1377,6 +1455,18 @@ impl From<&ast::StmtAugAssign> for DefinitionNodeKey { } } +impl From<&ast::StmtWhile> for DefinitionNodeKey { + fn from(node: &ast::StmtWhile) -> Self { + Self(NodeKey::from_node(node)) + } +} + +impl From<&ast::StmtFor> for DefinitionNodeKey { + fn from(node: &ast::StmtFor) -> Self { + Self(NodeKey::from_node(node)) + } +} + impl From<&ast::Parameter> for DefinitionNodeKey { fn from(node: &ast::Parameter) -> Self { Self(NodeKey::from_node(node)) diff --git a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs index 5761e95249be1..3776218e8d4f3 100644 --- a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs @@ -231,7 +231,7 @@ use crate::types::{ /// /// reachability constraints are normalized, so equivalent constraints are guaranteed to have equal /// IDs. -#[derive(Clone, Copy, Eq, Hash, PartialEq, get_size2::GetSize)] +#[derive(Clone, Copy, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)] pub(crate) struct ScopedReachabilityConstraintId(u32); impl std::fmt::Debug for ScopedReachabilityConstraintId { diff --git a/crates/ty_python_semantic/src/semantic_index/use_def.rs b/crates/ty_python_semantic/src/semantic_index/use_def.rs index 4a32870bafa4a..c0597e87a8e51 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def.rs @@ -260,13 +260,16 @@ use crate::semantic_index::scope::{FileScopeId, ScopeKind, ScopeLaziness}; use crate::semantic_index::symbol::ScopedSymbolId; use crate::semantic_index::use_def::place_state::{ Bindings, Declarations, EnclosingSnapshot, LiveBindingsIterator, LiveDeclaration, - LiveDeclarationsIterator, PlaceState, PreviousDefinitions, ScopedDefinitionId, + LiveDeclarationsIterator, PlaceState, }; use crate::semantic_index::{EnclosingSnapshotResult, SemanticIndex}; use crate::types::{PossiblyNarrowedPlaces, Truthiness, Type}; mod place_state; +pub(super) use place_state::PreviousDefinitions; +pub(crate) use place_state::{LiveBinding, ScopedDefinitionId}; + /// Applicable definitions and constraints for every use of a name. #[derive(Debug, PartialEq, Eq, salsa::Update, get_size2::GetSize)] pub(crate) struct UseDefMap<'db> { @@ -389,7 +392,7 @@ impl<'db> UseDefMap<'db> { } } - pub(super) fn is_reachable( + pub(crate) fn is_reachable( &self, db: &dyn crate::Db, reachability: ScopedReachabilityConstraintId, @@ -399,6 +402,21 @@ impl<'db> UseDefMap<'db> { .may_be_true() } + pub(crate) fn definition(&self, id: ScopedDefinitionId) -> DefinitionState<'db> { + self.all_definitions[id] + } + + pub(crate) fn narrowing_evaluator( + &self, + constraint: ScopedNarrowingConstraint, + ) -> NarrowingEvaluator<'_, 'db> { + NarrowingEvaluator { + constraint, + predicates: &self.predicates, + reachability_constraints: &self.reachability_constraints, + } + } + /// Check whether or not a given expression is reachable from the start of the scope. This /// is a local analysis which does not capture the possibility that the entire scope might /// be unreachable. Use [`super::SemanticIndex::is_node_reachable`] for the global @@ -690,7 +708,7 @@ type EnclosingSnapshots = IndexVec #[derive(Clone, Debug)] pub(crate) struct BindingWithConstraintsIterator<'map, 'db> { - all_definitions: &'map IndexVec>, + pub(crate) all_definitions: &'map IndexVec>, pub(crate) predicates: &'map Predicates<'db>, pub(crate) reachability_constraints: &'map ReachabilityConstraints, pub(crate) boundness_analysis: BoundnessAnalysis, @@ -919,7 +937,12 @@ impl<'db> UseDefMapBuilder<'db> { } } - pub(super) fn record_binding(&mut self, place: ScopedPlaceId, binding: Definition<'db>) { + pub(super) fn record_binding( + &mut self, + place: ScopedPlaceId, + binding: Definition<'db>, + previous_definitions: PreviousDefinitions, + ) { let bindings = match place { ScopedPlaceId::Symbol(symbol) => self.symbol_states[symbol].bindings(), ScopedPlaceId::Member(member) => self.member_states[member].bindings(), @@ -935,11 +958,13 @@ impl<'db> UseDefMapBuilder<'db> { }; self.declarations_by_binding .insert(binding, place_state.declarations().clone()); + place_state.record_binding( def_id, self.reachability, self.is_class_scope, place.is_symbol(), + previous_definitions, ); let bindings = match place { @@ -1219,6 +1244,7 @@ impl<'db> UseDefMapBuilder<'db> { self.reachability, self.is_class_scope, place.is_symbol(), + PreviousDefinitions::AreShadowed, ); let reachable_definitions = match place { @@ -1252,6 +1278,7 @@ impl<'db> UseDefMapBuilder<'db> { self.reachability, self.is_class_scope, place.is_symbol(), + PreviousDefinitions::AreShadowed, ); } @@ -1335,6 +1362,21 @@ impl<'db> UseDefMapBuilder<'db> { } } + /// Get a snapshot of the current bindings for a place. We use this at the end of loop bodies + /// to populate the loop header definitions (bindings in the loop body that are visible via + /// loop-back to prior uses in the loop body and also to the loop condition). + pub(super) fn loop_back_bindings( + &self, + place: ScopedPlaceId, + ) -> impl Iterator + '_ { + let bindings = match place { + ScopedPlaceId::Symbol(symbol) => self.symbol_states[symbol].bindings(), + ScopedPlaceId::Member(member) => self.member_states[member].bindings(), + }; + + bindings.iter().copied() + } + /// Restore the current builder places state to the given snapshot. pub(super) fn restore(&mut self, snapshot: FlowSnapshot) { // We never remove places from `place_states` (it's an IndexVec, and the place diff --git a/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs b/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs index 033f0aa5426d3..71833f3406397 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs @@ -53,8 +53,8 @@ use crate::semantic_index::reachability_constraints::{ /// A newtype-index for a definition in a particular scope. #[newtype_index] -#[derive(Ord, PartialOrd, get_size2::GetSize)] -pub(super) struct ScopedDefinitionId; +#[derive(Ord, PartialOrd, salsa::Update, get_size2::GetSize)] +pub(crate) struct ScopedDefinitionId; impl ScopedDefinitionId { /// A special ID that is used to describe an implicit start-of-scope state. When @@ -62,9 +62,9 @@ impl ScopedDefinitionId { /// unbound or undeclared at a given usage site. /// When creating a use-def-map builder, we always add an empty `DefinitionState::Undefined` definition /// at index 0, so this ID is always present. - pub(super) const UNBOUND: ScopedDefinitionId = ScopedDefinitionId::from_u32(0); + pub(crate) const UNBOUND: ScopedDefinitionId = ScopedDefinitionId::from_u32(0); - fn is_unbound(self) -> bool { + pub(crate) fn is_unbound(self) -> bool { self == Self::UNBOUND } } @@ -87,7 +87,7 @@ pub(super) struct LiveDeclaration { pub(super) type LiveDeclarationsIterator<'a> = std::slice::Iter<'a, LiveDeclaration>; #[derive(Clone, Copy, Debug)] -pub(super) enum PreviousDefinitions { +pub(in crate::semantic_index) enum PreviousDefinitions { AreShadowed, AreKept, } @@ -232,11 +232,11 @@ impl Bindings { } /// One of the live bindings for a single place at some point in control flow. -#[derive(Clone, Debug, PartialEq, Eq, get_size2::GetSize)] -pub(super) struct LiveBinding { - pub(super) binding: ScopedDefinitionId, - pub(super) narrowing_constraint: ScopedNarrowingConstraint, - pub(super) reachability_constraint: ScopedReachabilityConstraintId, +#[derive(Clone, Copy, Debug, PartialEq, Eq, salsa::Update, get_size2::GetSize)] +pub(crate) struct LiveBinding { + pub(crate) binding: ScopedDefinitionId, + pub(crate) narrowing_constraint: ScopedNarrowingConstraint, + pub(crate) reachability_constraint: ScopedReachabilityConstraintId, } pub(super) type LiveBindingsIterator<'a> = std::slice::Iter<'a, LiveBinding>; @@ -380,6 +380,7 @@ impl PlaceState { reachability_constraint: ScopedReachabilityConstraintId, is_class_scope: bool, is_place_name: bool, + previous_definitions: PreviousDefinitions, ) { debug_assert_ne!(binding_id, ScopedDefinitionId::UNBOUND); self.bindings.record_binding( @@ -387,7 +388,7 @@ impl PlaceState { reachability_constraint, is_class_scope, is_place_name, - PreviousDefinitions::AreShadowed, + previous_definitions, ); } @@ -509,6 +510,7 @@ mod tests { ScopedReachabilityConstraintId::ALWAYS_TRUE, false, true, + PreviousDefinitions::AreShadowed, ); assert_bindings(&sym, &[(1, ScopedNarrowingConstraint::ALWAYS_TRUE)]); @@ -523,6 +525,7 @@ mod tests { ScopedReachabilityConstraintId::ALWAYS_TRUE, false, true, + PreviousDefinitions::AreShadowed, ); let atom = reachability_constraints.add_atom(ScopedPredicateId::new(0)); sym.record_narrowing_constraint(&mut reachability_constraints, atom); @@ -541,6 +544,7 @@ mod tests { ScopedReachabilityConstraintId::ALWAYS_TRUE, false, true, + PreviousDefinitions::AreShadowed, ); let atom0 = reachability_constraints.add_atom(ScopedPredicateId::new(0)); sym1a.record_narrowing_constraint(&mut reachability_constraints, atom0); @@ -551,6 +555,7 @@ mod tests { ScopedReachabilityConstraintId::ALWAYS_TRUE, false, true, + PreviousDefinitions::AreShadowed, ); sym1b.record_narrowing_constraint(&mut reachability_constraints, atom0); @@ -566,6 +571,7 @@ mod tests { ScopedReachabilityConstraintId::ALWAYS_TRUE, false, true, + PreviousDefinitions::AreShadowed, ); let atom1 = reachability_constraints.add_atom(ScopedPredicateId::new(1)); sym2a.record_narrowing_constraint(&mut reachability_constraints, atom1); @@ -576,6 +582,7 @@ mod tests { ScopedReachabilityConstraintId::ALWAYS_TRUE, false, true, + PreviousDefinitions::AreShadowed, ); let atom2 = reachability_constraints.add_atom(ScopedPredicateId::new(2)); sym1b.record_narrowing_constraint(&mut reachability_constraints, atom2); @@ -596,6 +603,7 @@ mod tests { ScopedReachabilityConstraintId::ALWAYS_TRUE, false, true, + PreviousDefinitions::AreShadowed, ); let atom3 = reachability_constraints.add_atom(ScopedPredicateId::new(3)); sym3a.record_narrowing_constraint(&mut reachability_constraints, atom3); diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index f7c8234e63998..373065fa86bfb 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -13476,24 +13476,30 @@ pub(crate) mod tests { assert!(!div.is_redundant_with(&db, Type::unknown())); assert!(!Type::unknown().is_redundant_with(&db, div)); - let truthy_div = IntersectionBuilder::new(&db) + // `Divergent & T` and `Divergent & ~T` both simplify to `Divergent`, except for the + // specific case of `Divergent & Never`, which simplifies to `Never`. + let divergent_intersection = IntersectionBuilder::new(&db) .add_positive(div) - .add_negative(Type::AlwaysFalsy) + .add_positive(todo_type!("2")) + .add_negative(todo_type!("3")) .build(); - - let union = UnionType::from_elements(&db, [Type::unknown(), truthy_div]); - assert!(!truthy_div.is_redundant_with(&db, Type::unknown())); - assert_eq!( - union.display(&db).to_string(), - "Unknown | (Divergent & ~AlwaysFalsy)" - ); - - let union = UnionType::from_elements(&db, [truthy_div, Type::unknown()]); - assert!(!Type::unknown().is_redundant_with(&db, truthy_div)); - assert_eq!( - union.display(&db).to_string(), - "(Divergent & ~AlwaysFalsy) | Unknown" - ); + assert_eq!(divergent_intersection, div); + let divergent_intersection = IntersectionBuilder::new(&db) + .add_positive(todo_type!("2")) + .add_negative(todo_type!("3")) + .add_positive(div) + .build(); + assert_eq!(divergent_intersection, div); + let divergent_never_intersection = IntersectionBuilder::new(&db) + .add_positive(div) + .add_positive(Type::Never) + .build(); + assert_eq!(divergent_never_intersection, Type::Never); + let divergent_never_intersection = IntersectionBuilder::new(&db) + .add_positive(Type::Never) + .add_positive(div) + .build(); + assert_eq!(divergent_never_intersection, Type::Never); // The `object` type has a good convergence property, that is, its union with all other types is `object`. // (e.g. `object | tuple[Divergent] == object`, `object | tuple[object] == object`) diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 34bdc816b18e3..1ba0388b32e6e 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -1035,6 +1035,32 @@ struct InnerIntersectionBuilder<'db> { impl<'db> InnerIntersectionBuilder<'db> { /// Adds a positive type to this intersection. fn add_positive(&mut self, db: &'db dyn Db, mut new_positive: Type<'db>) { + // `Never & T` -> `Never` + if self.positive.contains(&Type::Never) { + return; + } + + // `T & Never` -> `Never` + if new_positive.is_never() { + *self = Self::default(); + self.positive.insert(Type::Never); + return; + } + + // `T & Divergent` -> `Divergent`. Conceptually, `Divergent` behaves like `Never` here and + // dominates intersections. However, `Divergent` is actually a dynamic/gradual type, so + // `~Divergent` acts like `Divergent` rather than dropping out like `~Never` does. + // `Divergent` also gets a lot of special handling in cycle recovery. + if new_positive.is_divergent() { + *self = Self::default(); + self.positive.insert(new_positive); + return; + } + // `Divergent & T` -> `Divergent` + if self.positive.iter().any(Type::is_divergent) { + return; + } + match new_positive { // `LiteralString & AlwaysTruthy` -> `LiteralString & ~Literal[""]` Type::AlwaysTruthy if self.positive.contains(&Type::LiteralString) => { @@ -1175,6 +1201,13 @@ impl<'db> InnerIntersectionBuilder<'db> { /// Adds a negative type to this intersection. fn add_negative(&mut self, db: &'db dyn Db, new_negative: Type<'db>) { + // `Divergent & ~T` -> `Divergent`. Note that `~Divergent` becomes `Divergent` via the + // `Type::Dynamic` branch below, so we don't need a special case for that. + if self.positive.iter().any(Type::is_divergent) { + debug_assert_eq!(self.positive.len(), 1, "`Divergent` should be alone"); + return; + } + let contains_bool = || { self.positive .iter() diff --git a/crates/ty_python_semantic/src/types/ide_support.rs b/crates/ty_python_semantic/src/types/ide_support.rs index f64b67e7bea7a..55b430a9ce888 100644 --- a/crates/ty_python_semantic/src/types/ide_support.rs +++ b/crates/ty_python_semantic/src/types/ide_support.rs @@ -1529,7 +1529,8 @@ mod resolve_definition { | DefinitionKind::ExceptHandler(_) | DefinitionKind::TypeVar(_) | DefinitionKind::ParamSpec(_) - | DefinitionKind::TypeVarTuple(_) => { + | DefinitionKind::TypeVarTuple(_) + | DefinitionKind::LoopHeader(_) => { // Not yet implemented return Err(()); } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 479821feca79b..5b1b9d47c4251 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -44,7 +44,7 @@ use crate::semantic_index::ast_ids::{HasScopedUseId, ScopedUseId}; use crate::semantic_index::definition::{ AnnotatedAssignmentDefinitionKind, AssignmentDefinitionKind, ComprehensionDefinitionKind, Definition, DefinitionKind, DefinitionNodeKey, DefinitionState, ExceptHandlerDefinitionKind, - ForStmtDefinitionKind, TargetKind, WithItemDefinitionKind, + ForStmtDefinitionKind, LoopHeaderDefinitionKind, TargetKind, WithItemDefinitionKind, }; use crate::semantic_index::expression::{Expression, ExpressionKind}; use crate::semantic_index::narrowing_constraints::ConstraintKey; @@ -55,8 +55,9 @@ use crate::semantic_index::scope::{ use crate::semantic_index::symbol::{ScopedSymbolId, Symbol}; use crate::semantic_index::{ ApplicableConstraints, EnclosingSnapshotResult, SemanticIndex, attribute_assignments, - place_table, + get_loop_header, place_table, }; +use crate::types::builder::RecursivelyDefined; use crate::types::call::bind::{CallableDescription, MatchingOverloadIndex}; use crate::types::call::{Argument, Binding, Bindings, CallArguments, CallError, CallErrorKind}; use crate::types::class::{ @@ -2334,6 +2335,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { DefinitionKind::TypeVarTuple(node) => { self.infer_typevartuple_definition(node.node(self.module()), definition); } + DefinitionKind::LoopHeader(loop_header) => { + self.infer_loop_header_definition(loop_header, definition); + } } } @@ -4946,6 +4950,59 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ); } + /// Infer the type for a loop header definition. + /// + /// The loop header sees all bindings that loop-back, either by reaching the end of the loop + /// body or a `continue` statement. This can include bindings from before the loop too, though + /// that's technically redundant, since the loop header definition itself doesn't shadow those + /// bindings. See `struct LoopHeader` in the semantic index for more on how all this fits + /// together. + fn infer_loop_header_definition( + &mut self, + loop_header_kind: &LoopHeaderDefinitionKind<'db>, + definition: Definition<'db>, + ) { + let db = self.db(); + let loop_token = loop_header_kind.loop_token(); + let place = loop_header_kind.place(); + let loop_header = get_loop_header(db, loop_token); + let use_def = self + .index + .use_def_map(self.scope().file_scope_id(self.db())); + + let mut union = UnionBuilder::new(db).recursively_defined(RecursivelyDefined::Yes); + + for live_binding in loop_header.bindings_for_place(place) { + // Skip unreachable bindings. + if !use_def.is_reachable(db, live_binding.reachability_constraint) { + continue; + } + + // Boundness analysis is handled by looking at these bindings again in + // `place_from_bindings_impl`. Here we're only concerned with the type. + let def_state = use_def.definition(live_binding.binding); + let def = match def_state { + DefinitionState::Defined(def) => def, + DefinitionState::Deleted | DefinitionState::Undefined => continue, + }; + + // This loop header is visible to itself. Filter it out to avoid a pointless cycle. + if def == definition { + continue; + } + + let binding_ty = binding_type(db, def); + let narrowed_ty = use_def + .narrowing_evaluator(live_binding.narrowing_constraint) + .narrow(db, binding_ty, place); + + union.add_in_place(narrowed_ty); + } + + self.bindings + .insert(definition, union.build(), self.multi_inference_state); + } + fn infer_match_statement(&mut self, match_statement: &ast::StmtMatch) { let ast::StmtMatch { range: _, diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index b628d88552bf9..0fb970609488e 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -952,6 +952,28 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { ) -> Option> { let op = if is_positive { op } else { op.negate() }; + // `Divergent` shows up as an initial value in cycle recovery. If it appears on either side + // of a potentially narrowing comparison, we don't want it to turn that comparison into a + // no-op (e.g. because `Divergent` is not a singleton in the `IsNot` branch below), because + // that could result in an initial type inference result that's too wide. Then, even if the + // next cycle iteration resolved all the `Divergent` values and correctly narrowed the + // type, we'd be stuck with the too-wide answer from the first iteration, because + // `Type::cycle_normalized` only ever widens and never narrows from one iteration to the + // next (to avoid oscillations). To prevent this, we have `Divergent` "poison" any value + // that's compared to it, so that `Type::cycle_normalized` can see it and skip the widening + // union step. + // + // For an extended discussion of the case that originally encountered this problem, see the + // "`Divergent` in narrowing conditions doesn't run afoul of 'monotonic widening' in cycle + // recovery" mdtest case in `while_loop.md`. See also + // https://github.com/astral-sh/ruff/pull/22794#issuecomment-3852095578. + if lhs_ty.is_divergent() { + return Some(lhs_ty); + } + if rhs_ty.is_divergent() { + return Some(rhs_ty); + } + match op { ast::CmpOp::IsNot => { if rhs_ty.is_singleton(self.db) { @@ -1110,7 +1132,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { // if union["tag"] == "foo": // reveal_type(union) # Foo // - // Importantly, `my_typeddict_union["tag"]` isn't the place we're going to constraint. + // Importantly, `my_typeddict_union["tag"]` isn't the place we're going to constrain. // Instead, we're going to constrain `my_typeddict_union` itself. if matches!(&**ops, [ast::CmpOp::Eq | ast::CmpOp::NotEq]) { // For `==`, we use equality semantics on the `if` branch (is_positive=true).