diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/isinstance.md b/crates/ty_python_semantic/resources/mdtest/narrow/isinstance.md index b07d929a43d5d..b844b0c33c93d 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/isinstance.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/isinstance.md @@ -614,8 +614,8 @@ class InvariantWithAny[T: int]: def _(x: object): if isinstance(x, InvariantWithAny): - reveal_type(x) # revealed: Top[InvariantWithAny[Unknown]] - reveal_type(x.a) # revealed: object + reveal_type(x) # revealed: Top[InvariantWithAny[Unknown & int]] + reveal_type(x.a) # revealed: int reveal_type(x.b) # revealed: Any ``` @@ -704,6 +704,214 @@ def _(x: object): reveal_type(x.y) # revealed: tuple[A, object] ``` +When a type parameter has a bound, `isinstance` narrowing should use the bound as the upper limit +for covariant type parameters. When a type parameter has a default, the default is not used; instead +the upper bound (or `object` if unbounded) is used: + +```py +from typing_extensions import Generic, TypeVar, ParamSpec, Callable +from ty_extensions import into_callable + +class UpperBound: ... +class Constraint1: ... +class Constraint2: ... +class UnionBoundElement1: ... +class UnionBoundElement2: ... + +T_contra = TypeVar("T_contra", contravariant=True) +T_contra_bound = TypeVar("T_contra_bound", bound=UpperBound, contravariant=True) +T_contra_constrained = TypeVar("T_contra_constrained", Constraint1, Constraint2, contravariant=True) +T_contra_defaulted = TypeVar("T_contra_defaulted", default=None, contravariant=True) +T_contra_defaulted_and_bound = TypeVar( + "T_contra_defaulted_and_bound", default=UnionBoundElement1, bound=UnionBoundElement1 | UnionBoundElement2, contravariant=True +) +T_contra_defaulted_and_constrained = TypeVar( + "T_contra_defaulted_and_constrained", Constraint1, Constraint2, default=Constraint1, contravariant=True +) + +class Contravariant( + Generic[ + T_contra, + T_contra_bound, + T_contra_constrained, + T_contra_defaulted, + T_contra_defaulted_and_bound, + T_contra_defaulted_and_constrained, + ] +): + def method( + self, + t_contra: T_contra, + t_contra_bound: T_contra_bound, + t_contra_constrained: T_contra_constrained, + t_contra_defaulted: T_contra_defaulted, + t_contra_defaulted_and_bound: T_contra_defaulted_and_bound, + t_contra_defaulted_and_constrained: T_contra_defaulted_and_constrained, + ): ... + +def test_contravariant_narrowing(x: object): + if isinstance(x, Contravariant): + reveal_type(x) # revealed: Contravariant[Never, Never, Never, Never, Never, Never] + # revealed: bound method Contravariant[Never, Never, Never, Never, Never, Never].method(t_contra: Never, t_contra_bound: Never, t_contra_constrained: Never, t_contra_defaulted: Never, t_contra_defaulted_and_bound: Never, t_contra_defaulted_and_constrained: Never) -> Unknown + reveal_type(x.method) + +T_co = TypeVar("T_co", covariant=True) +T_co_bound = TypeVar("T_co_bound", bound=UpperBound, covariant=True) +T_co_defaulted = TypeVar("T_co_defaulted", default=None, covariant=True) +T_co_defaulted_and_bound = TypeVar( + "T_co_defaulted_and_bound", default=UnionBoundElement1, bound=UnionBoundElement1 | UnionBoundElement2, covariant=True +) + +class Covariant1(Generic[T_co, T_co_bound, T_co_defaulted, T_co_defaulted_and_bound]): + def t_co(self) -> T_co: + raise NotImplementedError + + def t_co_bound(self) -> T_co_bound: + raise NotImplementedError + + def t_co_defaulted(self) -> T_co_defaulted: + raise NotImplementedError + + def t_co_defaulted_and_bound(self) -> T_co_defaulted_and_bound: + raise NotImplementedError + +T_co_constrained = TypeVar("T_co_constrained", Constraint1, Constraint2, covariant=True) +T_co_defaulted_and_constrained = TypeVar( + "T_co_defaulted_and_constrained", Constraint1, Constraint2, default=Constraint1, covariant=True +) + +class Covariant2(Generic[T_co_constrained, T_co_defaulted_and_constrained]): + def t_co_constrained(self) -> T_co_constrained: + raise NotImplementedError + + def t_co_defaulted_and_constrained(self) -> T_co_defaulted_and_constrained: + raise NotImplementedError + +def test_covariant_narrowing(x: object): + if isinstance(x, Covariant1): + reveal_type(x) # revealed: Covariant1[object, UpperBound, object, UnionBoundElement1 | UnionBoundElement2] + reveal_type(x.t_co_bound()) # revealed: UpperBound + reveal_type(x.t_co_defaulted()) # revealed: object + reveal_type(x.t_co_defaulted_and_bound()) # revealed: UnionBoundElement1 | UnionBoundElement2 + + if isinstance(x, Covariant2): + # TODO: solving a constrained TypeVar to anything except `Unknown` or one of its constraints is invalid. + # A more accurate revealed type here would be + # + # ( + # Covariant2[Constraint1, Constraint1] + # | Covariant2[Constraint1, Constraint2] + # | Covariant2[Constraint2, Constraint1] + # | Covariant2[Constraint2, Constraint2] + # ) + # + # revealed: Covariant2[Constraint1 | Constraint2, Constraint1 | Constraint2] + reveal_type(x) + + reveal_type(x.t_co_constrained()) # revealed: Constraint1 | Constraint2 + reveal_type(x.t_co_defaulted_and_constrained()) # revealed: Constraint1 | Constraint2 + +T = TypeVar("T") +T_bound = TypeVar("T_bound", bound=UpperBound) +T_constrained = TypeVar("T_constrained", Constraint1, Constraint2) +T_defaulted = TypeVar("T_defaulted", default=None) +T_defaulted_and_bound = TypeVar( + "T_defaulted_and_bound", default=UnionBoundElement1, bound=UnionBoundElement1 | UnionBoundElement2 +) +T_defaulted_and_constrained = TypeVar("T_defaulted_and_constrained", Constraint1, Constraint2, default=Constraint1) +P = ParamSpec("P") +P_defaulted = ParamSpec("P_defaulted", default=[int, str]) + +class Invariant1(Generic[T, T_bound, P, T_defaulted, T_defaulted_and_bound, P_defaulted]): + t: T + t_bound: T_bound + t_defaulted: T_defaulted + t_defaulted_and_bound: T_defaulted_and_bound + callable_attr: Callable[P, None] + defaulted_callable_attr: Callable[P_defaulted, None] + + def method(self, *args: P.args, **kwargs: P.kwargs) -> None: ... + +class Invariant2(Generic[T_constrained, T_defaulted_and_constrained]): + t_constrained: T_constrained + t_defaulted_and_constrained: T_defaulted_and_constrained + +def test_invariant_narrowing_from_object(obj: object): + if isinstance(obj, Invariant1): + # revealed: Top[Invariant1[Unknown, Unknown & UpperBound, Top[(...)], Unknown, (Unknown & UnionBoundElement1) | (Unknown & UnionBoundElement2), Top[(...)]]] + reveal_type(obj) + reveal_type(obj.t) # revealed: object + reveal_type(obj.t_bound) # revealed: UpperBound + reveal_type(obj.t_defaulted) # revealed: object + reveal_type(obj.t_defaulted_and_bound) # revealed: UnionBoundElement1 | UnionBoundElement2 + + reveal_type(obj.callable_attr) # revealed: Top[(...) -> None] + reveal_type(obj.defaulted_callable_attr) # revealed: Top[(...) -> None] + + # TODO: should probably be `(*args: Never, **kwargs: Never) -> None`? + reveal_type(into_callable(obj.method)) # revealed: (*args: object, **kwargs: object) -> None + + if isinstance(obj, Invariant2): + # TODO: solving a constrained TypeVar to anything except `Unknown` or one of its constraints is invalid. + # A more accurate revealed type here would be + # + # ( + # Invariant2[Constraint1, Constraint1] + # | Invariant2[Constraint1, Constraint2] + # | Invariant2[Constraint2, Constraint1] + # | Invariant2[Constraint2, Constraint2] + # ) + # + # revealed: Top[Invariant2[(Unknown & Constraint1) | (Unknown & Constraint2), (Unknown & Constraint1) | (Unknown & Constraint2)]] + reveal_type(obj) + + reveal_type(obj.t_constrained) # revealed: Constraint1 | Constraint2 + reveal_type(obj.t_defaulted_and_constrained) # revealed: Constraint1 | Constraint2 + +def test_invariant_narrowing_from_unspecialized_instance( + invariant_1_unspecialized: Invariant1, invariant_2_unspecialized: Invariant2 +): + if isinstance(invariant_1_unspecialized, Invariant1): + # revealed: Invariant1[Unknown, Unknown, (...), None, UnionBoundElement1, (int, str, /)] + reveal_type(invariant_1_unspecialized) + reveal_type(invariant_1_unspecialized.t) # revealed: Unknown + reveal_type(invariant_1_unspecialized.t_bound) # revealed: Unknown + reveal_type(invariant_1_unspecialized.t_defaulted) # revealed: None + reveal_type(invariant_1_unspecialized.t_defaulted_and_bound) # revealed: UnionBoundElement1 + + reveal_type(invariant_1_unspecialized.callable_attr) # revealed: (...) -> None + reveal_type(invariant_1_unspecialized.defaulted_callable_attr) # revealed: (int, str, /) -> None + + reveal_type(into_callable(invariant_1_unspecialized.method)) # revealed: (...) -> None + + if isinstance(invariant_2_unspecialized, Invariant2): + reveal_type(invariant_2_unspecialized) # revealed: Invariant2[Unknown, Constraint1] + reveal_type(invariant_2_unspecialized.t_constrained) # revealed: Unknown + reveal_type(invariant_2_unspecialized.t_defaulted_and_constrained) # revealed: Constraint1 + +def test_invariant_narrowing_from_specialized_instance( + invariant_1_specialized: Invariant1[int, UpperBound, [int, str], int, UnionBoundElement1, [int, str]], + invariant_2_specialized: Invariant2[Constraint1, Constraint2], +): + if isinstance(invariant_1_specialized, Invariant1): + # revealed: Invariant1[int, UpperBound, (int, str, /), int, UnionBoundElement1, (int, str, /)] + reveal_type(invariant_1_specialized) + reveal_type(invariant_1_specialized.t) # revealed: int + reveal_type(invariant_1_specialized.t_bound) # revealed: UpperBound + reveal_type(invariant_1_specialized.t_defaulted) # revealed: int + reveal_type(invariant_1_specialized.t_defaulted_and_bound) # revealed: UnionBoundElement1 + + reveal_type(invariant_1_specialized.callable_attr) # revealed: (int, str, /) -> None + reveal_type(invariant_1_specialized.defaulted_callable_attr) # revealed: (int, str, /) -> None + + reveal_type(into_callable(invariant_1_specialized.method)) # revealed: (int, str, /) -> None + + if isinstance(invariant_2_specialized, Invariant2): + reveal_type(invariant_2_specialized) # revealed: Invariant2[Constraint1, Constraint2] + reveal_type(invariant_2_specialized.t_constrained) # revealed: Constraint1 + reveal_type(invariant_2_specialized.t_defaulted_and_constrained) # revealed: Constraint2 +``` + ## Narrowing with TypedDict unions Narrowing unions of `int` and multiple TypedDicts using `isinstance(x, dict)` should not panic diff --git a/crates/ty_python_semantic/resources/mdtest/overloads.md b/crates/ty_python_semantic/resources/mdtest/overloads.md index c38d640f01b73..e859a3e938675 100644 --- a/crates/ty_python_semantic/resources/mdtest/overloads.md +++ b/crates/ty_python_semantic/resources/mdtest/overloads.md @@ -897,3 +897,43 @@ def baz(x, y, z=None) -> bytes | list[str]: # revealed: Overload[(x, y) -> bytes, (x, y, z) -> list[str]] reveal_type(baz) ``` + +### Overload solving in cases involving type variables bound to gradual types + +```toml +[environment] +python-version = "3.12" +``` + +`library.pyi`: + +```pyi +from typing import Any, Never, overload + +class Foo: ... + +class Bar[T: Any]: + def get1(self) -> T: ... + +@overload +def foo(obj: Bar[Foo]) -> Never: ... +@overload +def foo(obj) -> Foo: ... +``` + +`app.py`: + +```py +from library import Bar, foo, Foo + +def test(obj: Bar): + # Not all materializations of `Bar` (== `Bar[Unknown]`) are assignable to + # the parameter type of the first overload (`Bar[Foo]`), so neither overload + # can be eliminated by step 5 of the overload evaluation algorithm. The return + # types of the two overloads are not equivalent, so we must assume a return type + # of `Any`/`Unknown` and stop, according to step 5. + reveal_type(foo(obj)) # revealed: Unknown + +def test2(obj: Bar[Foo]): + reveal_type(foo(obj)) # revealed: Never +``` diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/materialization.md b/crates/ty_python_semantic/resources/mdtest/type_properties/materialization.md index 27c41fc27461d..44e80c4fc7fef 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/materialization.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/materialization.md @@ -581,6 +581,130 @@ def contravariant(top: Top[ContravariantCallable], bottom: Bottom[ContravariantC reveal_type(bottom) # revealed: (GenericContravariant[Never], /) -> None ``` +## Bounded TypeVars + +When materializing a generic class, bounded TypeVars should have their materialized type capped by +the upper bound. For example, `Unknown` materializes to `object` for a covariant TypeVar, but if the +TypeVar has `bound=int`, the result should be `int` instead: + +```toml +[environment] +python-version = "3.13" +``` + +```py +from typing_extensions import Any, Generic, TypeVar, Never, ParamSpec, Callable +from ty_extensions import Bottom, Top, static_assert, is_equivalent_to + +T_bounded_co = TypeVar("T_bounded_co", bound=int, covariant=True) +T_default_co = TypeVar("T_default_co", default=None, covariant=True) +T_bounded_and_defaulted_co = TypeVar("T_bounded_and_defaulted_co", default=None, bound=int | None, covariant=True) + +class BoundedAndDefaulted(Generic[T_bounded_co, T_default_co, T_bounded_and_defaulted_co]): + def get_bounded(self) -> T_bounded_co: + raise NotImplementedError + + def get_defaulted(self) -> T_default_co: + raise NotImplementedError + + def get_bounded_and_defaulted(self) -> T_bounded_and_defaulted_co: + raise NotImplementedError + +# `Top[BoundedAndDefaulted]` translates to "the top materialization of the default specialization of the `BoundedAndDefaulted` type". +# I.e. `Top[BoundedAndDefaulted]` -> `Top[BoundedAndDefaulted[Unknown, None]]` -> `BoundedAndDefaulted[int, None]`. +# +# This is consistent with how an unparameterized class always means "the class with the default specialization applied" +# in type expressions, but note that it means that `Top[BoundedAndDefaulted]` here is therefore a different type +# to the "true top materialization" of `BoundedAndDefaulted` that we would intersect a type with for `isinstance()` +# narrowing. That "true top materialization" is tested elsewhere, in `isinstance.md`. +def f(x: Top[BoundedAndDefaulted]): + reveal_type(x) # revealed: BoundedAndDefaulted[int, None, None] + reveal_type(x.get_bounded()) # revealed: int + reveal_type(x.get_defaulted()) # revealed: None + reveal_type(x.get_bounded_and_defaulted()) # revealed: None + +def g(x: Top[BoundedAndDefaulted[Any, Any, Any]]): + reveal_type(x) # revealed: BoundedAndDefaulted[int, object, int | None] + reveal_type(x.get_bounded()) # revealed: int + reveal_type(x.get_defaulted()) # revealed: object + reveal_type(x.get_bounded_and_defaulted()) # revealed: int | None + +def h(x: Top[BoundedAndDefaulted[int, int, int]]): + reveal_type(x) # revealed: BoundedAndDefaulted[int, int, int] + reveal_type(x.get_bounded()) # revealed: int + reveal_type(x.get_defaulted()) # revealed: int + reveal_type(x.get_bounded_and_defaulted()) # revealed: int + +T_bounded_invariant = TypeVar("T_bounded_invariant", bound=int) +T_defaulted_invariant = TypeVar("T_defaulted_invariant", default=None) +T_bounded_and_defaulted_invariant = TypeVar("T_bounded_and_defaulted_invariant", default=None, bound=int | None) +P = ParamSpec("P") +P_defaulted = ParamSpec("P_defaulted", default=[int, str]) + +class BoundedAndDefaultedInvariant( + Generic[T_bounded_invariant, P, T_defaulted_invariant, T_bounded_and_defaulted_invariant, P_defaulted] +): + t_bounded: T_bounded_invariant + t_defaulted: T_defaulted_invariant + t_bounded_and_defaulted: T_bounded_and_defaulted_invariant + callable_attr: Callable[P, None] + callable_attr_defaulted: Callable[P_defaulted, None] + +def i(x: Top[BoundedAndDefaultedInvariant]): + # revealed: Top[BoundedAndDefaultedInvariant[Unknown & int, Top[(...)], None, None, (int, str, /)]] + reveal_type(x) + reveal_type(x.t_bounded) # revealed: int + reveal_type(x.t_defaulted) # revealed: None + reveal_type(x.t_bounded_and_defaulted) # revealed: None + reveal_type(x.callable_attr) # revealed: Top[(...) -> None] + reveal_type(x.callable_attr_defaulted) # revealed: (int, str, /) -> None + +def j(x: Top[BoundedAndDefaultedInvariant[Any, ..., Any, Any, ...]]): + # revealed: Top[BoundedAndDefaultedInvariant[Any & int, Top[(...)], Any, (Any & int) | (Any & None), Top[(...)]]] + reveal_type(x) + reveal_type(x.t_bounded) # revealed: int + reveal_type(x.t_defaulted) # revealed: object + reveal_type(x.t_bounded_and_defaulted) # revealed: int | None + reveal_type(x.callable_attr) # revealed: Top[(...) -> None] + reveal_type(x.callable_attr_defaulted) # revealed: Top[(...) -> None] + +def k(x: Top[BoundedAndDefaultedInvariant[bool, [int, str], bool, bool, [int, str]]]): + # revealed: BoundedAndDefaultedInvariant[bool, (int, str, /), bool, bool, (int, str, /)] + reveal_type(x) + reveal_type(x.t_bounded) # revealed: bool + reveal_type(x.t_defaulted) # revealed: bool + reveal_type(x.t_bounded_and_defaulted) # revealed: bool + reveal_type(x.callable_attr) # revealed: (int, str, /) -> None + reveal_type(x.callable_attr_defaulted) # revealed: (int, str, /) -> None + +T_bounded_contra = TypeVar("T_bounded_contra", bound=int, contravariant=True) +T_defaulted_contra = TypeVar("T_defaulted_contra", default=None, contravariant=True) +T_bounded_and_defaulted_contra = TypeVar("T_bounded_and_defaulted_contra", default=None, bound=int | None, contravariant=True) + +class BoundedAndDefaultedContra(Generic[T_bounded_contra, T_defaulted_contra, T_bounded_and_defaulted_contra]): + def method( + self, + t_bounded: T_bounded_contra, + t_defaulted: T_defaulted_contra, + t_bounded_and_defaulted: T_bounded_and_defaulted_contra, + ): ... + +def l(x: Top[BoundedAndDefaultedContra]): + reveal_type(x) # revealed: BoundedAndDefaultedContra[Never, None, None] + # revealed: bound method BoundedAndDefaultedContra[Never, None, None].method(t_bounded: Never, t_defaulted: None, t_bounded_and_defaulted: None) -> Unknown + reveal_type(x.method) + +def m(x: Top[BoundedAndDefaultedContra[Any, Any, Any]]): + reveal_type(x) # revealed: BoundedAndDefaultedContra[Never, Never, Never] + # revealed: bound method BoundedAndDefaultedContra[Never, Never, Never].method(t_bounded: Never, t_defaulted: Never, t_bounded_and_defaulted: Never) -> Unknown + reveal_type(x.method) + +def n(x: Top[BoundedAndDefaultedContra[int, int, int]]): + reveal_type(x) # revealed: BoundedAndDefaultedContra[int, int, int] + # revealed: bound method BoundedAndDefaultedContra[int, int, int].method(t_bounded: int, t_defaulted: int, t_bounded_and_defaulted: int) -> Unknown + reveal_type(x.method) +``` + ## Invalid use `Top[]` and `Bottom[]` are special forms that take a single argument. diff --git a/crates/ty_python_semantic/src/types/class/known.rs b/crates/ty_python_semantic/src/types/class/known.rs index e071628535b64..814aea0742446 100644 --- a/crates/ty_python_semantic/src/types/class/known.rs +++ b/crates/ty_python_semantic/src/types/class/known.rs @@ -936,7 +936,7 @@ impl KnownClass { } class_literal - .apply_specialization(db, |_| generic_context.specialize(db, specialization)) + .apply_specialization(db, |_| generic_context.specialize(db, specialization, None)) } let class_literal = self.to_class_literal(db).as_class_literal()?.as_static()?; diff --git a/crates/ty_python_semantic/src/types/class/static_literal.rs b/crates/ty_python_semantic/src/types/class/static_literal.rs index fc9b526808c18..34d3a30a91fa8 100644 --- a/crates/ty_python_semantic/src/types/class/static_literal.rs +++ b/crates/ty_python_semantic/src/types/class/static_literal.rs @@ -381,7 +381,7 @@ impl<'db> StaticClassLiteral<'db> { pub(crate) fn top_materialization(self, db: &'db dyn Db) -> ClassType<'db> { self.apply_specialization(db, |generic_context| { generic_context - .default_specialization(db, self.known(db)) + .unknown_specialization(db, self.known(db)) .materialize_impl( db, MaterializationKind::Top, @@ -404,7 +404,7 @@ impl<'db> StaticClassLiteral<'db> { /// maps each of the class's typevars to `Unknown`. pub(crate) fn unknown_specialization(self, db: &'db dyn Db) -> ClassType<'db> { self.apply_specialization(db, |generic_context| { - generic_context.unknown_specialization(db) + generic_context.unknown_specialization(db, self.known(db)) }) } diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index cf8cf1ddb9bd2..5706a1bf1c7dd 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -779,16 +779,32 @@ impl<'db> GenericContext<'db> { /// Returns a specialization of this generic context where each typevar is mapped to itself. pub(crate) fn identity_specialization(self, db: &'db dyn Db) -> Specialization<'db> { let types: Vec = self.variables(db).map(Type::TypeVar).collect(); - self.specialize(db, types) + self.specialize(db, types, None) } - pub(crate) fn unknown_specialization(self, db: &'db dyn Db) -> Specialization<'db> { - match self.len(db) { - 0 => self.specialize(db, &[]), - 1 => self.specialize(db, &[Type::unknown(); 1]), - 2 => self.specialize(db, &[Type::unknown(); 2]), - len => self.specialize(db, vec![Type::unknown(); len]), - } + pub(crate) fn unknown_specialization( + self, + db: &'db dyn Db, + known_class: Option, + ) -> Specialization<'db> { + let tuple = if known_class == Some(KnownClass::Tuple) { + Some(TupleType::homogeneous(db, Type::unknown())) + } else { + None + }; + + let types: Vec> = self + .variables(db) + .map(|tvar| { + if tvar.is_paramspec(db) { + Type::paramspec_value_callable(db, Parameters::unknown()) + } else { + Type::unknown() + } + }) + .collect(); + + self.specialize(db, types, tuple) } pub(crate) fn is_subset_of(self, db: &'db dyn Db, other: GenericContext<'db>) -> bool { @@ -827,7 +843,12 @@ impl<'db> GenericContext<'db> { /// otherwise, you will be left with a partial specialization. (Use /// [`specialize_recursive`](Self::specialize_recursive) if your types might mention typevars /// in this generic context.) - pub(crate) fn specialize<'t, T>(self, db: &'db dyn Db, types: T) -> Specialization<'db> + pub(crate) fn specialize<'t, T>( + self, + db: &'db dyn Db, + types: T, + tuple: Option>, + ) -> Specialization<'db> where T: Into]>>, 'db: 't, @@ -835,7 +856,7 @@ impl<'db> GenericContext<'db> { let types = types.into(); assert_eq!(self.len(db), types.len()); - Specialization::new(db, self, types, None, None) + Specialization::new(db, self, types, None, tuple) } /// Creates a specialization of this generic context. Panics if the length of `types` does not @@ -1230,19 +1251,19 @@ impl<'db> Specialization<'db> { if self.materialization_kind(db).is_some() { return self; } + let mut has_dynamic_invariant_typevar = false; + let types: Box<[_]> = self .generic_context(db) .variables(db) .zip(self.types(db)) .map(|(bound_typevar, vartype)| { - match bound_typevar.variance(db) { - TypeVarVariance::Bivariant => { - // With bivariance, all specializations are subtypes of each other, - // so any materialization is acceptable. - vartype.materialize(db, MaterializationKind::Top, visitor) - } - TypeVarVariance::Covariant => { + let variance = bound_typevar.variance(db); + let mut materialized = match variance { + // With bivariance, all specializations are subtypes of each other, + // so any materialization is acceptable. + TypeVarVariance::Bivariant | TypeVarVariance::Covariant => { vartype.materialize(db, materialization_kind, visitor) } TypeVarVariance::Contravariant => { @@ -1256,9 +1277,27 @@ impl<'db> Specialization<'db> { } *vartype } + }; + + if matches!( + (materialization_kind, variance), + ( + MaterializationKind::Top, + TypeVarVariance::Covariant + | TypeVarVariance::Bivariant + | TypeVarVariance::Invariant + ) | (MaterializationKind::Bottom, TypeVarVariance::Contravariant) + ) && let Some(bounds) = bound_typevar.typevar(db).bound_or_constraints(db) + { + materialized = bounds + .materialize_impl(db, MaterializationKind::Top, visitor) + .intersect_with(db, materialized); } + + materialized }) .collect(); + let tuple_inner = self.tuple_inner(db).and_then(|tuple| { // Tuples are immutable, so tuple element types are always in covariant position. tuple.apply_type_mapping_impl( @@ -1268,11 +1307,13 @@ impl<'db> Specialization<'db> { visitor, ) }); + let new_materialization_kind = if has_dynamic_invariant_typevar { Some(materialization_kind) } else { None }; + Specialization::new( db, self.generic_context(db), @@ -1359,6 +1400,7 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { source_materialization_kind, *target_type, target_materialization_kind, + bound_typevar.typevar(db), ), TypeVarVariance::Covariant => { self.check_type_pair(db, *source_type, *target_type) @@ -1382,6 +1424,7 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { source_materialization: Option, target_type: Type<'db>, target_materialization: Option, + tvar: TypeVarInstance<'db>, ) -> ConstraintSet<'db, 'c> { match ( source_materialization, @@ -1422,13 +1465,32 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { TypeRelation::Subtyping | TypeRelation::Redundancy { .. } | TypeRelation::SubtypingAssuming, - ) => self.check_subtyping_in_invariant_position( - db, - source_type, - MaterializationKind::Top, - target_type, - target_mat, - ), + ) => { + // Ensure that all possible specializations of a class are considered + // subtypes of the top-materialization specialization of that class + // by intersecting `source_type` with the bounds/constraints of the + // type variable, if they exist. + let effective_source_type = tvar + .bound_or_constraints(db) + .map(|bound_or_constraints| { + bound_or_constraints + .materialize_impl( + db, + MaterializationKind::Top, + &ApplyTypeMappingVisitor::default(), + ) + .intersect_with(db, source_type) + }) + .unwrap_or(source_type); + + self.check_subtyping_in_invariant_position( + db, + effective_source_type, + MaterializationKind::Top, + target_type, + target_mat, + ) + } ( Some(source_mat), None, diff --git a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs index b3f0484dc79e7..b32b4d6c1f33b 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs @@ -1090,6 +1090,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { generic_context.len(db), ) .collect::>(), + None, ), ); return if in_type_expression { diff --git a/crates/ty_python_semantic/src/types/typevar.rs b/crates/ty_python_semantic/src/types/typevar.rs index 5f3e9fee7d727..2f8bde7c5a9f2 100644 --- a/crates/ty_python_semantic/src/types/typevar.rs +++ b/crates/ty_python_semantic/src/types/typevar.rs @@ -12,11 +12,11 @@ use crate::{ semantic_index, }, types::{ - ApplySpecialization, ApplyTypeMappingVisitor, CycleDetector, DynamicType, KnownClass, - KnownInstanceType, MaterializationKind, Parameter, Parameters, Type, TypeAliasType, - TypeContext, TypeMapping, TypeVarVariance, UnionBuilder, UnionType, any_over_type, - binding_type, definition_expression_type, tuple::Tuple, variance::VarianceInferable, - visitor, + ApplySpecialization, ApplyTypeMappingVisitor, CycleDetector, DynamicType, IntersectionType, + KnownClass, KnownInstanceType, MaterializationKind, Parameter, Parameters, Type, + TypeAliasType, TypeContext, TypeMapping, TypeVarVariance, UnionBuilder, UnionType, + any_over_type, binding_type, definition_expression_type, tuple::Tuple, + variance::VarianceInferable, visitor, }, }; @@ -1365,7 +1365,7 @@ pub(super) fn walk_type_var_bounds<'db, V: visitor::TypeVisitor<'db> + ?Sized>( } impl<'db> TypeVarBoundOrConstraints<'db> { - fn materialize_impl( + pub(super) fn materialize_impl( self, db: &'db dyn Db, materialization_kind: MaterializationKind, @@ -1384,6 +1384,25 @@ impl<'db> TypeVarBoundOrConstraints<'db> { } } } + + /// If `self` represents an upper bound, returns the intersection of the upper bound and `other`. + /// If `self` represents a set of constraints, returns the intersection of ` & other`. + pub(super) fn intersect_with(self, db: &'db dyn Db, other: Type<'db>) -> Type<'db> { + match self { + TypeVarBoundOrConstraints::UpperBound(bound) => { + IntersectionType::from_two_elements(db, other, bound) + } + // Conceptually the same as `IntersectionType::from_two_elements(db, constraints.as_type(db), other)`, + // but this gets us there a more direct way in the `UnionBuilder` + TypeVarBoundOrConstraints::Constraints(constraints) => UnionType::from_elements( + db, + constraints + .elements(db) + .iter() + .map(|&constraint| IntersectionType::from_two_elements(db, other, constraint)), + ), + } + } } /// A [`CycleDetector`] that is used in `TypeVarInstance::default_type`.