From 418413eab024af3ab48e1d51fc4fb9d8697299d5 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Sat, 29 Mar 2025 12:44:54 -0400 Subject: [PATCH 1/7] add narrowing from guard expressions --- .../resources/mdtest/narrow/match.md | 46 +++++++++++++++++++ .../src/semantic_index/builder.rs | 10 +++- .../src/types/narrow.rs | 1 - 3 files changed, 54 insertions(+), 3 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md index 27b01efe7b733..fd1b9eb2c9fba 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md @@ -165,3 +165,49 @@ match x: reveal_type(x) # revealed: object ``` + +## Narrowing due to guard + +```py +def get_object() -> object: + return object() + +x = get_object() + +reveal_type(x) # revealed: object + +match x: + case str() | float() if type(x) is str: + reveal_type(x) # revealed: str + case "foo" | 42 | None if isinstance(x, int): + reveal_type(x) # revealed: Literal[42] + case False if x: + reveal_type(x) # revealed: Never + case "foo" if x := "bar": + reveal_type(x) # revealed: Literal["bar"] + +reveal_type(x) # revealed: object +``` + +## Guard and reveal_type in guard + +```py +def get_object() -> object: + return object() + +x = get_object() + +reveal_type(x) # revealed: object + +match x: + case str() | float() if type(x) is str and reveal_type(x): # revealed: str + pass + case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: Literal[42] + pass + case False if x and reveal_type(x): # revealed: Never + pass + case "foo" if (x := "bar") and reveal_type(x): # revealed: Literal["bar"] + pass + +reveal_type(x) # revealed: object +``` diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index f2d2a30224c51..8a5f68d685e0c 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -1590,8 +1590,14 @@ where case.guard.as_deref(), ); self.record_reachability_constraint(predicate); - if let Some(expr) = &case.guard { - self.visit_expr(expr); + if let Some(guard) = &case.guard { + let guard_expr = self.add_standalone_expression(guard); + self.visit_expr(guard); + let predicate = Predicate { + node: PredicateNode::Expression(guard_expr), + is_positive: true, + }; + self.record_narrowing_constraint(predicate); } self.visit_body(&case.body); for id in &vis_constraints { diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index cf5431b47ec80..9860be0474b3c 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -302,7 +302,6 @@ impl<'db> NarrowingConstraintsBuilder<'db> { pattern: PatternPredicate<'db>, ) -> Option> { let subject = pattern.subject(self.db); - self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject) } From 8b444e7bf4b212606d995d3d080e098589f807a6 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Tue, 8 Apr 2025 00:25:50 -0400 Subject: [PATCH 2/7] more match narrowing --- .../resources/mdtest/narrow/match.md | 9 ++-- .../src/semantic_index/builder.rs | 53 ++++++++++++------- .../src/types/narrow.rs | 32 ++++++++++- 3 files changed, 68 insertions(+), 26 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md index fd1b9eb2c9fba..8fd2f7cfdde2a 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md @@ -39,8 +39,7 @@ match x: case A(): reveal_type(x) # revealed: A case B(): - # TODO could be `B & ~A` - reveal_type(x) # revealed: B + reveal_type(x) # revealed: B & ~A reveal_type(x) # revealed: object ``` @@ -88,7 +87,7 @@ match x: case 6.0: reveal_type(x) # revealed: float case 1j: - reveal_type(x) # revealed: complex + reveal_type(x) # revealed: complex & ~float case b"foo": reveal_type(x) # revealed: Literal[b"foo"] @@ -134,11 +133,11 @@ match x: case "foo" | 42 | None: reveal_type(x) # revealed: Literal["foo", 42] | None case "foo" | tuple(): - reveal_type(x) # revealed: Literal["foo"] | tuple + reveal_type(x) # revealed: tuple case True | False: reveal_type(x) # revealed: bool case 3.14 | 2.718 | 1.414: - reveal_type(x) # revealed: float + reveal_type(x) # revealed: float & ~tuple reveal_type(x) # revealed: object ``` diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 8a5f68d685e0c..220825f74e7f2 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -1572,50 +1572,65 @@ where return; } - let after_subject = self.flow_snapshot(); + let has_catchall = cases + .last() + .is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard()); + + let mut no_case_matched = self.flow_snapshot(); let mut vis_constraints = vec![]; let mut post_case_snapshots = vec![]; - for (i, case) in cases.iter().enumerate() { - if i != 0 { - post_case_snapshots.push(self.flow_snapshot()); - self.flow_restore(after_subject.clone()); - } + let mut match_predicate; + for (i, case) in cases.iter().enumerate() { self.current_match_case = Some(CurrentMatchCase::new(&case.pattern)); self.visit_pattern(&case.pattern); self.current_match_case = None; - let predicate = self.add_pattern_narrowing_constraint( + no_case_matched = self.flow_snapshot(); + match_predicate = self.add_pattern_narrowing_constraint( subject_expr, &case.pattern, case.guard.as_deref(), ); - self.record_reachability_constraint(predicate); - if let Some(guard) = &case.guard { + self.record_reachability_constraint(match_predicate); + + let match_success_guard_failure = case.guard.as_ref().map(|guard| { let guard_expr = self.add_standalone_expression(guard); self.visit_expr(guard); + let post_match_success = self.flow_snapshot(); let predicate = Predicate { node: PredicateNode::Expression(guard_expr), is_positive: true, }; + self.record_negated_narrowing_constraint(predicate); + let match_success_guard_failure = self.flow_snapshot(); + self.flow_restore(post_match_success); self.record_narrowing_constraint(predicate); - } + match_success_guard_failure + }); + self.visit_body(&case.body); for id in &vis_constraints { self.record_negated_visibility_constraint(*id); } - let vis_constraint_id = self.record_visibility_constraint(predicate); + let vis_constraint_id = self.record_visibility_constraint(match_predicate); vis_constraints.push(vis_constraint_id); + + post_case_snapshots.push(self.flow_snapshot()); + + if i != cases.len() - 1 || !has_catchall { + // We need to restore the state after each case, but not after the last one. + // The last one will just become the state that we merge the other snapshots into + self.flow_restore(no_case_matched.clone()); + self.record_negated_narrowing_constraint(match_predicate); + if let Some(match_success_guard_failure) = match_success_guard_failure { + self.flow_merge(match_success_guard_failure); + } + } } // If there is no final wildcard match case, pretend there is one. This is similar to how // we add an implicit `else` block in if-elif chains, in case it's not present. - if !cases - .last() - .is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard()) - { - post_case_snapshots.push(self.flow_snapshot()); - self.flow_restore(after_subject.clone()); - + if !has_catchall { for id in &vis_constraints { self.record_negated_visibility_constraint(*id); } @@ -1625,7 +1640,7 @@ where self.flow_merge(post_clause_state); } - self.simplify_visibility_constraints(after_subject); + self.simplify_visibility_constraints(no_case_matched); } ast::Stmt::Try(ast::StmtTry { body, diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 9860be0474b3c..f8ccf8092d687 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -50,7 +50,13 @@ pub(crate) fn infer_narrowing_constraint<'db>( all_negative_narrowing_constraints_for_expression(db, expression) } } - PredicateNode::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern), + PredicateNode::Pattern(pattern) => { + if predicate.is_positive { + all_narrowing_constraints_for_pattern(db, pattern) + } else { + all_negative_narrowing_constraints_for_pattern(db, pattern) + } + } PredicateNode::StarImportPlaceholder(_) => return None, }; if let Some(constraints) = constraints { @@ -95,6 +101,15 @@ fn all_negative_narrowing_constraints_for_expression<'db>( NarrowingConstraintsBuilder::new(db, PredicateNode::Expression(expression), false).finish() } +#[allow(clippy::ref_option)] +#[salsa::tracked(return_ref)] +fn all_negative_narrowing_constraints_for_pattern<'db>( + db: &'db dyn Db, + pattern: PatternPredicate<'db>, +) -> Option> { + NarrowingConstraintsBuilder::new(db, PredicateNode::Pattern(pattern), false).finish() +} + #[allow(clippy::ref_option)] fn constraints_for_expression_cycle_recover<'db>( _db: &'db dyn Db, @@ -217,6 +232,12 @@ fn merge_constraints_or<'db>( } } +fn negate_if<'db>(constraints: &mut NarrowingConstraints<'db>, db: &'db dyn Db, yes: bool) { + for (_symbol, ty) in constraints.iter_mut() { + *ty = ty.negate_if(db, yes); + } +} + struct NarrowingConstraintsBuilder<'db> { db: &'db dyn Db, predicate: PredicateNode<'db>, @@ -237,7 +258,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> { PredicateNode::Expression(expression) => { self.evaluate_expression_predicate(expression, self.is_positive) } - PredicateNode::Pattern(pattern) => self.evaluate_pattern_predicate(pattern), + PredicateNode::Pattern(pattern) => { + self.evaluate_pattern_predicate(pattern, self.is_positive) + } PredicateNode::StarImportPlaceholder(_) => return None, }; if let Some(mut constraints) = constraints { @@ -300,9 +323,14 @@ impl<'db> NarrowingConstraintsBuilder<'db> { fn evaluate_pattern_predicate( &mut self, pattern: PatternPredicate<'db>, + is_positive: bool, ) -> Option> { let subject = pattern.subject(self.db); self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject) + .map(|mut constraints| { + negate_if(&mut constraints, self.db, !is_positive); + constraints + }) } fn symbols(&self) -> Arc { From a9f4456e1579c10ca8bd9771365aff476aec8b04 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 9 Apr 2025 11:42:29 -0400 Subject: [PATCH 3/7] expand comment a bit --- .../red_knot_python_semantic/src/semantic_index/builder.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 220825f74e7f2..04264d6102370 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -1618,8 +1618,10 @@ where post_case_snapshots.push(self.flow_snapshot()); if i != cases.len() - 1 || !has_catchall { - // We need to restore the state after each case, but not after the last one. - // The last one will just become the state that we merge the other snapshots into + // We need to restore the state after each case, but not after the last + // one. The last one will just become the state that we merge the other + // snapshots into. (If there's no catch-all, we'll add an implied one + // below, so this can't be the last case.) self.flow_restore(no_case_matched.clone()); self.record_negated_narrowing_constraint(match_predicate); if let Some(match_success_guard_failure) = match_success_guard_failure { From 0fcba9d1647682a450a36113576893265b3d045c Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Thu, 10 Apr 2025 00:33:51 -0400 Subject: [PATCH 4/7] stop visiting case patterns in the no-case-matched path --- .../src/semantic_index/builder.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 04264d6102370..aa2b2c74f934a 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -1572,11 +1572,12 @@ where return; } + let mut no_case_matched = self.flow_snapshot(); + let has_catchall = cases .last() .is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard()); - let mut no_case_matched = self.flow_snapshot(); let mut vis_constraints = vec![]; let mut post_case_snapshots = vec![]; let mut match_predicate; @@ -1585,7 +1586,10 @@ where self.current_match_case = Some(CurrentMatchCase::new(&case.pattern)); self.visit_pattern(&case.pattern); self.current_match_case = None; - no_case_matched = self.flow_snapshot(); + // unlike in [Stmt::If], we don't reset [no_case_matched] + // here because the effects of visiting a pattern is binding + // symbols, and this doesn't occur unless the pattern + // actually matches match_predicate = self.add_pattern_narrowing_constraint( subject_expr, &case.pattern, @@ -1627,6 +1631,7 @@ where if let Some(match_success_guard_failure) = match_success_guard_failure { self.flow_merge(match_success_guard_failure); } + no_case_matched = self.flow_snapshot(); } } From 35dabdc5d604c567bb45fd48da5e89c1197d3bc8 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Thu, 17 Apr 2025 19:38:41 -0400 Subject: [PATCH 5/7] cleanup visibility constraint management --- .../src/semantic_index/builder.rs | 41 +++++++++++-------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index aa2b2c74f934a..52eb5c00258fb 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -568,16 +568,23 @@ impl<'db> SemanticIndexBuilder<'db> { id } + /// Constructs a visibility constraint id without recording it + fn visibility_constraint_id( + &mut self, + predicate: Predicate<'db>, + ) -> ScopedVisibilityConstraintId { + let predicate_id = self.current_use_def_map_mut().add_predicate(predicate); + self.current_visibility_constraints_mut() + .add_atom(predicate_id) + } + /// Records a visibility constraint by applying it to all live bindings and declarations. #[must_use = "A visibility constraint must always be negated after it is added"] fn record_visibility_constraint( &mut self, predicate: Predicate<'db>, ) -> ScopedVisibilityConstraintId { - let predicate_id = self.current_use_def_map_mut().add_predicate(predicate); - let id = self - .current_visibility_constraints_mut() - .add_atom(predicate_id); + let id = self.visibility_constraint_id(predicate); self.record_visibility_constraint_id(id); id } @@ -1595,28 +1602,26 @@ where &case.pattern, case.guard.as_deref(), ); - self.record_reachability_constraint(match_predicate); + let vis_constraint_id = self.record_reachability_constraint(match_predicate); let match_success_guard_failure = case.guard.as_ref().map(|guard| { let guard_expr = self.add_standalone_expression(guard); self.visit_expr(guard); - let post_match_success = self.flow_snapshot(); + let post_guard_eval = self.flow_snapshot(); let predicate = Predicate { node: PredicateNode::Expression(guard_expr), is_positive: true, }; self.record_negated_narrowing_constraint(predicate); let match_success_guard_failure = self.flow_snapshot(); - self.flow_restore(post_match_success); + self.flow_restore(post_guard_eval); self.record_narrowing_constraint(predicate); match_success_guard_failure }); + self.record_visibility_constraint_id(vis_constraint_id); + self.visit_body(&case.body); - for id in &vis_constraints { - self.record_negated_visibility_constraint(*id); - } - let vis_constraint_id = self.record_visibility_constraint(match_predicate); vis_constraints.push(vis_constraint_id); post_case_snapshots.push(self.flow_snapshot()); @@ -1624,23 +1629,23 @@ where if i != cases.len() - 1 || !has_catchall { // We need to restore the state after each case, but not after the last // one. The last one will just become the state that we merge the other - // snapshots into. (If there's no catch-all, we'll add an implied one - // below, so this can't be the last case.) + // snapshots into. self.flow_restore(no_case_matched.clone()); self.record_negated_narrowing_constraint(match_predicate); if let Some(match_success_guard_failure) = match_success_guard_failure { self.flow_merge(match_success_guard_failure); + } else { + assert!(case.guard.is_none()); } - no_case_matched = self.flow_snapshot(); + } else { + debug_assert!(match_success_guard_failure.is_none()); + debug_assert!(case.guard.is_none()); } - } - // If there is no final wildcard match case, pretend there is one. This is similar to how - // we add an implicit `else` block in if-elif chains, in case it's not present. - if !has_catchall { for id in &vis_constraints { self.record_negated_visibility_constraint(*id); } + no_case_matched = self.flow_snapshot(); } for post_clause_state in post_case_snapshots { From 4cc3753cb162979c4607ce0517b192bcd083d971 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Thu, 17 Apr 2025 18:09:30 -0700 Subject: [PATCH 6/7] re-inline unused method --- .../src/semantic_index/builder.rs | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 52eb5c00258fb..cc7990967d5ce 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -568,23 +568,16 @@ impl<'db> SemanticIndexBuilder<'db> { id } - /// Constructs a visibility constraint id without recording it - fn visibility_constraint_id( - &mut self, - predicate: Predicate<'db>, - ) -> ScopedVisibilityConstraintId { - let predicate_id = self.current_use_def_map_mut().add_predicate(predicate); - self.current_visibility_constraints_mut() - .add_atom(predicate_id) - } - /// Records a visibility constraint by applying it to all live bindings and declarations. #[must_use = "A visibility constraint must always be negated after it is added"] fn record_visibility_constraint( &mut self, predicate: Predicate<'db>, ) -> ScopedVisibilityConstraintId { - let id = self.visibility_constraint_id(predicate); + let predicate_id = self.current_use_def_map_mut().add_predicate(predicate); + let id = self + .current_visibility_constraints_mut() + .add_atom(predicate_id); self.record_visibility_constraint_id(id); id } From f65bcc18c298d6795f31d4cd6bc6fe1d76ad67b7 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Thu, 17 Apr 2025 18:12:22 -0700 Subject: [PATCH 7/7] remove unnecessary collection of visibility constraint IDs --- .../red_knot_python_semantic/src/semantic_index/builder.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index cc7990967d5ce..e4c25f48400b5 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -1578,7 +1578,6 @@ where .last() .is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard()); - let mut vis_constraints = vec![]; let mut post_case_snapshots = vec![]; let mut match_predicate; @@ -1615,7 +1614,6 @@ where self.record_visibility_constraint_id(vis_constraint_id); self.visit_body(&case.body); - vis_constraints.push(vis_constraint_id); post_case_snapshots.push(self.flow_snapshot()); @@ -1635,9 +1633,7 @@ where debug_assert!(case.guard.is_none()); } - for id in &vis_constraints { - self.record_negated_visibility_constraint(*id); - } + self.record_negated_visibility_constraint(vis_constraint_id); no_case_matched = self.flow_snapshot(); }