From ceea52a3c6eac8cc21ae461e65570693309b7799 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Thu, 30 Oct 2025 15:15:24 -0400 Subject: [PATCH 1/2] [ty] Use the top materialization of classes for narrowing in class-patterns for `match` statements --- .../resources/mdtest/narrow/match.md | 49 +++++++++++++++++++ .../reachability_constraints.rs | 5 +- crates/ty_python_semantic/src/types/narrow.rs | 8 +-- 3 files changed, 57 insertions(+), 5 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/match.md b/crates/ty_python_semantic/resources/mdtest/narrow/match.md index ee51d50af25aa5..ee95e8336bc8f2 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/match.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/match.md @@ -69,6 +69,55 @@ match x: reveal_type(x) # revealed: object ``` +## Class patterns with generic classes + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import assert_never + +class Covariant[T]: + def get(self) -> T: + raise NotImplementedError + +def f(x: Covariant[int]): + match x: + case Covariant(): + reveal_type(x) # revealed: Covariant[int] + case _: + reveal_type(x) # revealed: Never + assert_never(x) +``` + +## Class patterns with generic `@final` classes + +These work the same as non-`@final` classes. + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import assert_never, final + +@final +class Covariant[T]: + def get(self) -> T: + raise NotImplementedError + +def f(x: Covariant[int]): + match x: + case Covariant(): + reveal_type(x) # revealed: Covariant[int] + case _: + reveal_type(x) # revealed: Never + assert_never(x) +``` + ## Value patterns Value patterns are evaluated by equality, which is overridable. Therefore successfully matching on diff --git a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs index af3ef642e3cd5c..1224190209381b 100644 --- a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs @@ -771,8 +771,9 @@ impl ReachabilityConstraints { truthiness } PatternPredicateKind::Class(class_expr, kind) => { - let class_ty = - infer_expression_type(db, *class_expr, TypeContext::default()).to_instance(db); + let class_ty = infer_expression_type(db, *class_expr, TypeContext::default()) + .as_class_literal() + .map(|class| Type::instance(db, class.top_materialization(db))); class_ty.map_or(Truthiness::Ambiguous, |class_ty| { if subject_ty.is_subtype_of(db, class_ty) { diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 736272cb4a94f2..2d4a714fd44a00 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -962,10 +962,12 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let subject = place_expr(subject.node_ref(self.db, self.module))?; let place = self.expect_place(&subject); - let ty = infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module) - .to_instance(self.db)?; + let class = + infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module) + .as_class_literal()? + .top_materialization(self.db); - let ty = ty.negate_if(self.db, !is_positive); + let ty = Type::instance(self.db, class).negate_if(self.db, !is_positive); Some(NarrowingConstraints::from_iter([(place, ty)])) } From 449b8ab5cd8affe2fbff83abc54205fe05e32c96 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Thu, 30 Oct 2025 16:25:52 -0400 Subject: [PATCH 2/2] fix `jax` false positives --- .../mdtest/exhaustiveness_checking.md | 19 ++++++++++++++ .../resources/mdtest/narrow/match.md | 26 +++++++++++++++++++ crates/ty_python_semantic/src/types/narrow.rs | 24 +++++++++++------ 3 files changed, 61 insertions(+), 8 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md b/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md index 72183597509f53..29b267024bbe30 100644 --- a/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md +++ b/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md @@ -182,6 +182,11 @@ def match_non_exhaustive(x: Color): ## `isinstance` checks +```toml +[environment] +python-version = "3.12" +``` + ```py from typing import assert_never @@ -189,6 +194,9 @@ class A: ... class B: ... class C: ... +class GenericClass[T]: + x: T + def if_else_exhaustive(x: A | B | C): if isinstance(x, A): pass @@ -253,6 +261,17 @@ def match_non_exhaustive(x: A | B | C): # this diagnostic is correct: the inferred type of `x` is `B & ~A & ~C` assert_never(x) # error: [type-assertion-failure] + +# Note: no invalid-return-type diagnostic; the `match` is exhaustive +def match_exhaustive_generic[T](obj: GenericClass[T]) -> GenericClass[T]: + match obj: + case GenericClass(x=42): + reveal_type(obj) # revealed: GenericClass[T@match_exhaustive_generic] + return obj + case GenericClass(x=x): + reveal_type(x) # revealed: @Todo(`match` pattern definition types) + reveal_type(obj) # revealed: GenericClass[T@match_exhaustive_generic] + return obj ``` ## `isinstance` checks with generics diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/match.md b/crates/ty_python_semantic/resources/mdtest/narrow/match.md index ee95e8336bc8f2..f0c107851b67b0 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/match.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/match.md @@ -118,6 +118,32 @@ def f(x: Covariant[int]): assert_never(x) ``` +## Class patterns where the class pattern does not resolve to a class + +In general this does not allow for narrowing, but we make an exception for `Any`. This is to support +[real ecosystem code](https://github.com/jax-ml/jax/blob/d2ce04b6c3d03ae18b145965b8b8b92e09e8009c/jax/_src/pallas/mosaic_gpu/lowering.py#L3372-L3387) +found in `jax`. + +```py +from typing import Any + +X = Any + +def f(obj: object): + match obj: + case int(): + reveal_type(obj) # revealed: int + case X(): + reveal_type(obj) # revealed: Any & ~int + +def g(obj: object, Y: Any): + match obj: + case int(): + reveal_type(obj) # revealed: int + case Y(): + reveal_type(obj) # revealed: Any & ~int +``` + ## Value patterns Value patterns are evaluated by equality, which is overridable. Therefore successfully matching on diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 2d4a714fd44a00..5b709551f5166e 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -11,8 +11,9 @@ use crate::types::enums::{enum_member_literals, enum_metadata}; use crate::types::function::KnownFunction; use crate::types::infer::infer_same_file_expression_type; use crate::types::{ - ClassLiteral, ClassType, IntersectionBuilder, KnownClass, SubclassOfInner, SubclassOfType, - Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, UnionBuilder, infer_expression_types, + ClassLiteral, ClassType, IntersectionBuilder, KnownClass, SpecialFormType, SubclassOfInner, + SubclassOfType, Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, UnionBuilder, + infer_expression_types, }; use ruff_db::parsed::{ParsedModuleRef, parsed_module}; @@ -962,13 +963,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let subject = place_expr(subject.node_ref(self.db, self.module))?; let place = self.expect_place(&subject); - let class = - infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module) - .as_class_literal()? - .top_materialization(self.db); + let class_type = + infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module); - let ty = Type::instance(self.db, class).negate_if(self.db, !is_positive); - Some(NarrowingConstraints::from_iter([(place, ty)])) + let narrowed_type = match class_type { + Type::ClassLiteral(class) => { + Type::instance(self.db, class.top_materialization(self.db)) + .negate_if(self.db, !is_positive) + } + dynamic @ Type::Dynamic(_) => dynamic, + Type::SpecialForm(SpecialFormType::Any) => Type::any(), + _ => return None, + }; + + Some(NarrowingConstraints::from_iter([(place, narrowed_type)])) } fn evaluate_match_pattern_value(