Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 50 additions & 5 deletions crates/red_knot_python_semantic/resources/mdtest/narrow/match.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ match x:
case A():
reveal_type(x) # revealed: A
case B():
# TODO could be `B & ~A`
reveal_type(x) # revealed: B
reveal_type(x) # revealed: B & ~A

reveal_type(x) # revealed: object
```
Expand Down Expand Up @@ -88,7 +87,7 @@ match x:
case 6.0:
reveal_type(x) # revealed: float
case 1j:
reveal_type(x) # revealed: complex
reveal_type(x) # revealed: complex & ~float
case b"foo":
reveal_type(x) # revealed: Literal[b"foo"]

Expand Down Expand Up @@ -134,11 +133,11 @@ match x:
case "foo" | 42 | None:
reveal_type(x) # revealed: Literal["foo", 42] | None
case "foo" | tuple():
reveal_type(x) # revealed: Literal["foo"] | tuple
reveal_type(x) # revealed: tuple
case True | False:
reveal_type(x) # revealed: bool
case 3.14 | 2.718 | 1.414:
reveal_type(x) # revealed: float
reveal_type(x) # revealed: float & ~tuple

reveal_type(x) # revealed: object
```
Expand All @@ -165,3 +164,49 @@ match x:

reveal_type(x) # revealed: object
```

## Narrowing due to guard

```py
def get_object() -> object:
return object()

x = get_object()

reveal_type(x) # revealed: object

match x:
case str() | float() if type(x) is str:
reveal_type(x) # revealed: str
case "foo" | 42 | None if isinstance(x, int):
reveal_type(x) # revealed: Literal[42]
case False if x:
reveal_type(x) # revealed: Never
case "foo" if x := "bar":
reveal_type(x) # revealed: Literal["bar"]

reveal_type(x) # revealed: object
```

## Guard and reveal_type in guard

```py
def get_object() -> object:
return object()

x = get_object()

reveal_type(x) # revealed: object

match x:
case str() | float() if type(x) is str and reveal_type(x): # revealed: str
pass
case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: Literal[42]
pass
case False if x and reveal_type(x): # revealed: Never
pass
case "foo" if (x := "bar") and reveal_type(x): # revealed: Literal["bar"]
pass

reveal_type(x) # revealed: object
```
78 changes: 50 additions & 28 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1572,54 +1572,76 @@ where
return;
}

let after_subject = self.flow_snapshot();
let mut vis_constraints = vec![];
let mut no_case_matched = self.flow_snapshot();

let has_catchall = cases
.last()
.is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard());

let mut post_case_snapshots = vec![];
for (i, case) in cases.iter().enumerate() {
if i != 0 {
post_case_snapshots.push(self.flow_snapshot());
self.flow_restore(after_subject.clone());
}
let mut match_predicate;

for (i, case) in cases.iter().enumerate() {
self.current_match_case = Some(CurrentMatchCase::new(&case.pattern));
self.visit_pattern(&case.pattern);
self.current_match_case = None;
let predicate = self.add_pattern_narrowing_constraint(
// unlike in [Stmt::If], we don't reset [no_case_matched]
// here because the effects of visiting a pattern is binding
// symbols, and this doesn't occur unless the pattern
// actually matches
match_predicate = self.add_pattern_narrowing_constraint(
subject_expr,
&case.pattern,
case.guard.as_deref(),
);
self.record_reachability_constraint(predicate);
if let Some(expr) = &case.guard {
self.visit_expr(expr);
}
let vis_constraint_id = self.record_reachability_constraint(match_predicate);

let match_success_guard_failure = case.guard.as_ref().map(|guard| {
let guard_expr = self.add_standalone_expression(guard);
self.visit_expr(guard);
let post_guard_eval = self.flow_snapshot();
let predicate = Predicate {
node: PredicateNode::Expression(guard_expr),
is_positive: true,
};
self.record_negated_narrowing_constraint(predicate);
let match_success_guard_failure = self.flow_snapshot();
self.flow_restore(post_guard_eval);
self.record_narrowing_constraint(predicate);
match_success_guard_failure
});

self.record_visibility_constraint_id(vis_constraint_id);

self.visit_body(&case.body);
for id in &vis_constraints {
self.record_negated_visibility_constraint(*id);
}
let vis_constraint_id = self.record_visibility_constraint(predicate);
vis_constraints.push(vis_constraint_id);
}

// If there is no final wildcard match case, pretend there is one. This is similar to how
// we add an implicit `else` block in if-elif chains, in case it's not present.
if !cases
.last()
.is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard())
{
post_case_snapshots.push(self.flow_snapshot());
self.flow_restore(after_subject.clone());

for id in &vis_constraints {
self.record_negated_visibility_constraint(*id);
if i != cases.len() - 1 || !has_catchall {
// We need to restore the state after each case, but not after the last
// one. The last one will just become the state that we merge the other
// snapshots into.
self.flow_restore(no_case_matched.clone());
self.record_negated_narrowing_constraint(match_predicate);
if let Some(match_success_guard_failure) = match_success_guard_failure {
self.flow_merge(match_success_guard_failure);
} else {
assert!(case.guard.is_none());
}
} else {
debug_assert!(match_success_guard_failure.is_none());
debug_assert!(case.guard.is_none());
}

self.record_negated_visibility_constraint(vis_constraint_id);
no_case_matched = self.flow_snapshot();
}

for post_clause_state in post_case_snapshots {
self.flow_merge(post_clause_state);
}

self.simplify_visibility_constraints(after_subject);
self.simplify_visibility_constraints(no_case_matched);
}
ast::Stmt::Try(ast::StmtTry {
body,
Expand Down
33 changes: 30 additions & 3 deletions crates/red_knot_python_semantic/src/types/narrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,13 @@ pub(crate) fn infer_narrowing_constraint<'db>(
all_negative_narrowing_constraints_for_expression(db, expression)
}
}
PredicateNode::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern),
PredicateNode::Pattern(pattern) => {
if predicate.is_positive {
all_narrowing_constraints_for_pattern(db, pattern)
} else {
all_negative_narrowing_constraints_for_pattern(db, pattern)
}
}
PredicateNode::StarImportPlaceholder(_) => return None,
};
if let Some(constraints) = constraints {
Expand Down Expand Up @@ -95,6 +101,15 @@ fn all_negative_narrowing_constraints_for_expression<'db>(
NarrowingConstraintsBuilder::new(db, PredicateNode::Expression(expression), false).finish()
}

#[allow(clippy::ref_option)]
#[salsa::tracked(return_ref)]
fn all_negative_narrowing_constraints_for_pattern<'db>(
db: &'db dyn Db,
pattern: PatternPredicate<'db>,
) -> Option<NarrowingConstraints<'db>> {
NarrowingConstraintsBuilder::new(db, PredicateNode::Pattern(pattern), false).finish()
}

#[allow(clippy::ref_option)]
fn constraints_for_expression_cycle_recover<'db>(
_db: &'db dyn Db,
Expand Down Expand Up @@ -217,6 +232,12 @@ fn merge_constraints_or<'db>(
}
}

fn negate_if<'db>(constraints: &mut NarrowingConstraints<'db>, db: &'db dyn Db, yes: bool) {
for (_symbol, ty) in constraints.iter_mut() {
*ty = ty.negate_if(db, yes);
}
}

struct NarrowingConstraintsBuilder<'db> {
db: &'db dyn Db,
predicate: PredicateNode<'db>,
Expand All @@ -237,7 +258,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
PredicateNode::Expression(expression) => {
self.evaluate_expression_predicate(expression, self.is_positive)
}
PredicateNode::Pattern(pattern) => self.evaluate_pattern_predicate(pattern),
PredicateNode::Pattern(pattern) => {
self.evaluate_pattern_predicate(pattern, self.is_positive)
}
PredicateNode::StarImportPlaceholder(_) => return None,
};
if let Some(mut constraints) = constraints {
Expand Down Expand Up @@ -300,10 +323,14 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
fn evaluate_pattern_predicate(
&mut self,
pattern: PatternPredicate<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let subject = pattern.subject(self.db);

self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject)
.map(|mut constraints| {
negate_if(&mut constraints, self.db, !is_positive);
constraints
})
}

fn symbols(&self) -> Arc<SymbolTable> {
Expand Down
Loading