diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 0a6733db233a9..327893821ddc6 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -143,7 +143,7 @@ impl<'db> SemanticIndexBuilder<'db> { self.current_use_def_map().restore(state); } - fn flow_merge(&mut self, state: FlowSnapshot) { + fn flow_merge(&mut self, state: &FlowSnapshot) { self.current_use_def_map().merge(state); } @@ -393,27 +393,27 @@ where self.visit_expr(&node.test); let pre_if = self.flow_snapshot(); self.visit_body(&node.body); - let mut last_clause_is_else = false; - let mut post_clauses: Vec = vec![self.flow_snapshot()]; + let mut post_clauses: Vec = vec![]; for clause in &node.elif_else_clauses { - // we can only take an elif/else clause if none of the previous ones were taken + // snapshot after every block except the last; the last one will just become + // the state that we merge the other snapshots into + post_clauses.push(self.flow_snapshot()); + // we can only take an elif/else branch if none of the previous ones were + // taken, so the block entry state is always `pre_if` self.flow_restore(pre_if.clone()); self.visit_elif_else_clause(clause); - post_clauses.push(self.flow_snapshot()); - if clause.test.is_none() { - last_clause_is_else = true; - } } - let mut post_clause_iter = post_clauses.into_iter(); - if last_clause_is_else { - // if the last clause was an else, the pre_if state can't directly reach the - // post-state; we must enter one of the clauses. - self.flow_restore(post_clause_iter.next().unwrap()); - } else { - self.flow_restore(pre_if); + for post_clause_state in post_clauses { + self.flow_merge(&post_clause_state); } - for post_clause_state in post_clause_iter { - self.flow_merge(post_clause_state); + let has_else = node + .elif_else_clauses + .last() + .is_some_and(|clause| clause.test.is_none()); + if !has_else { + // if there's no else clause, then it's possible we took none of the branches, + // and the pre_if state can reach here + self.flow_merge(&pre_if); } } _ => { @@ -485,7 +485,7 @@ where let post_body = self.flow_snapshot(); self.flow_restore(pre_if); self.visit_expr(orelse); - self.flow_merge(post_body); + self.flow_merge(&post_body); } _ => { walk_expr(self, expr); diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs index 9e501a30a88f6..79c7ad8a2a61d 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs @@ -253,9 +253,9 @@ impl<'db> UseDefMapBuilder<'db> { /// Restore the current builder visible-definitions state to the given snapshot. pub(super) fn restore(&mut self, snapshot: FlowSnapshot) { - // We never remove symbols from `definitions_by_symbol` (its an IndexVec, and the symbol - // IDs need to line up), so the current number of recorded symbols must always be equal or - // greater than the number of symbols in a previously-recorded snapshot. + // We never remove symbols from `definitions_by_symbol` (it's an IndexVec, and the symbol + // IDs must line up), so the current number of known symbols must always be equal to or + // greater than the number of known symbols in a previously-taken snapshot. let num_symbols = self.definitions_by_symbol.len(); debug_assert!(num_symbols >= snapshot.definitions_by_symbol.len()); @@ -272,8 +272,7 @@ impl<'db> UseDefMapBuilder<'db> { /// Merge the given snapshot into the current state, reflecting that we might have taken either /// path to get here. The new visible-definitions state for each symbol should include /// definitions from both the prior state and the snapshot. - #[allow(clippy::needless_pass_by_value)] - pub(super) fn merge(&mut self, snapshot: FlowSnapshot) { + pub(super) fn merge(&mut self, snapshot: &FlowSnapshot) { // The tricky thing about merging two Ranges pointing into `all_definitions` is that if the // two Ranges aren't already adjacent in `all_definitions`, we will have to copy at least // one or the other of the ranges to the end of `all_definitions` so as to make them @@ -282,48 +281,60 @@ impl<'db> UseDefMapBuilder<'db> { // It's possible we may end up with some old entries in `all_definitions` that nobody is // pointing to, but that's OK. - for (symbol_id, to_merge) in snapshot.definitions_by_symbol.iter_enumerated() { - let current = &mut self.definitions_by_symbol[symbol_id]; + // We never remove symbols from `definitions_by_symbol` (it's an IndexVec, and the symbol + // IDs must line up), so the current number of known symbols must always be equal to or + // greater than the number of known symbols in a previously-taken snapshot. + debug_assert!(self.definitions_by_symbol.len() >= snapshot.definitions_by_symbol.len()); + + for (symbol_id, current) in self.definitions_by_symbol.iter_mut_enumerated() { + let Some(snapshot) = snapshot.definitions_by_symbol.get(symbol_id) else { + // Symbol not present in snapshot, so it's unbound from that path. + current.may_be_unbound = true; + continue; + }; // If the symbol can be unbound in either predecessor, it can be unbound post-merge. - current.may_be_unbound |= to_merge.may_be_unbound; + current.may_be_unbound |= snapshot.may_be_unbound; // Merge the definition ranges. - if current.definitions_range == to_merge.definitions_range { - // Ranges already identical, nothing to do! - } else if current.definitions_range.end == to_merge.definitions_range.start { - // Ranges are adjacent (`current` first), just merge them into one range. - current.definitions_range = - (current.definitions_range.start)..(to_merge.definitions_range.end); - } else if current.definitions_range.start == to_merge.definitions_range.end { - // Ranges are adjacent (`to_merge` first), just merge them into one range. - current.definitions_range = - (to_merge.definitions_range.start)..(current.definitions_range.end); - } else if current.definitions_range.end == self.all_definitions.len() { - // Ranges are not adjacent, `current` is at the end of `all_definitions`, we need - // to copy `to_merge` to the end so they are adjacent and can be merged into one - // range. - self.all_definitions - .extend_from_within(to_merge.definitions_range.clone()); - current.definitions_range.end = self.all_definitions.len(); - } else if to_merge.definitions_range.end == self.all_definitions.len() { - // Ranges are not adjacent, `to_merge` is at the end of `all_definitions`, we need - // to copy `current` to the end so they are adjacent and can be merged into one - // range. - self.all_definitions - .extend_from_within(current.definitions_range.clone()); - current.definitions_range.start = to_merge.definitions_range.start; - current.definitions_range.end = self.all_definitions.len(); + let current = &mut current.definitions_range; + let snapshot = &snapshot.definitions_range; + + // We never create reversed ranges. + debug_assert!(current.end >= current.start); + debug_assert!(snapshot.end >= snapshot.start); + + if current == snapshot { + // Ranges already identical, nothing to do. + } else if snapshot.is_empty() { + // Merging from an empty range; nothing to do. + } else if (*current).is_empty() { + // Merging to an empty range; just use the incoming range. + *current = snapshot.clone(); + } else if snapshot.end >= current.start && snapshot.start <= current.end { + // Ranges are adjacent or overlapping, merge them in-place. + *current = current.start.min(snapshot.start)..current.end.max(snapshot.end); + } else if current.end == self.all_definitions.len() { + // Ranges are not adjacent or overlapping, `current` is at the end of + // `all_definitions`, we need to copy `snapshot` to the end so they are adjacent + // and can be merged into one range. + self.all_definitions.extend_from_within(snapshot.clone()); + current.end = self.all_definitions.len(); + } else if snapshot.end == self.all_definitions.len() { + // Ranges are not adjacent or overlapping, `snapshot` is at the end of + // `all_definitions`, we need to copy `current` to the end so they are adjacent and + // can be merged into one range. + self.all_definitions.extend_from_within(current.clone()); + current.start = snapshot.start; + current.end = self.all_definitions.len(); } else { // Ranges are not adjacent and neither one is at the end of `all_definitions`, we // have to copy both to the end so they are adjacent and we can merge them. let start = self.all_definitions.len(); - self.all_definitions - .extend_from_within(current.definitions_range.clone()); - self.all_definitions - .extend_from_within(to_merge.definitions_range.clone()); - current.definitions_range.start = start; - current.definitions_range.end = self.all_definitions.len(); + self.all_definitions.extend_from_within(current.clone()); + self.all_definitions.extend_from_within(snapshot.clone()); + current.start = start; + current.end = self.all_definitions.len(); } } } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index bdc3ec8cce655..9d3c7f40669f3 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -1094,6 +1094,87 @@ mod tests { Ok(()) } + #[test] + fn if_elif_else_single_symbol() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + if flag: + y = 1 + elif flag2: + y = 2 + else: + y = 3 + ", + )?; + + assert_public_ty(&db, "src/a.py", "y", "Literal[1, 2, 3]"); + Ok(()) + } + + #[test] + fn if_elif_else_no_definition_in_else() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + y = 0 + if flag: + y = 1 + elif flag2: + y = 2 + else: + pass + ", + )?; + + assert_public_ty(&db, "src/a.py", "y", "Literal[0, 1, 2]"); + Ok(()) + } + + #[test] + fn if_elif_else_no_definition_in_else_one_intervening_definition() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + y = 0 + if flag: + y = 1 + z = 3 + elif flag2: + y = 2 + else: + pass + ", + )?; + + assert_public_ty(&db, "src/a.py", "y", "Literal[0, 1, 2]"); + Ok(()) + } + + #[test] + fn nested_if() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + y = 0 + if flag: + if flag2: + y = 1 + ", + )?; + + assert_public_ty(&db, "src/a.py", "y", "Literal[0, 1]"); + Ok(()) + } + #[test] fn if_elif() -> anyhow::Result<()> { let mut db = setup_db(); diff --git a/crates/ruff_index/src/slice.rs b/crates/ruff_index/src/slice.rs index 804aa1fbda2a6..9b3f9523f7a9c 100644 --- a/crates/ruff_index/src/slice.rs +++ b/crates/ruff_index/src/slice.rs @@ -80,6 +80,13 @@ impl IndexSlice { self.raw.iter_mut() } + #[inline] + pub fn iter_mut_enumerated( + &mut self, + ) -> impl DoubleEndedIterator + ExactSizeIterator + '_ { + self.raw.iter_mut().enumerate().map(|(n, t)| (I::new(n), t)) + } + #[inline] pub fn last_index(&self) -> Option { self.len().checked_sub(1).map(I::new)