diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md index 3cdebe848e5619..9c67fd6cba18a4 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md @@ -481,22 +481,99 @@ reveal_type(g(f("a"))) # revealed: tuple[Literal["a"], int] | None ## Passing generic functions to generic functions -```py +`functions.pyi`: + +```pyi from typing import Callable -def invoke[A, B](fn: Callable[[A], B], value: A) -> B: - return fn(value) +def invoke[A, B](fn: Callable[[A], B], value: A) -> B: ... -def identity[T](x: T) -> T: - return x +def identity[T](x: T) -> T: ... + +class Covariant[T]: + def get(self) -> T: + raise NotImplementedError -def head[T](xs: list[T]) -> T: - return xs[0] +def head_covariant[T](xs: Covariant[T]) -> T: ... +def lift_covariant[T](xs: T) -> Covariant[T]: ... + +class Contravariant[T]: + def receive(self, input: T): ... + +def head_contravariant[T](xs: Contravariant[T]) -> T: ... +def lift_contravariant[T](xs: T) -> Contravariant[T]: ... + +class Invariant[T]: + mutable_attribute: T + +def head_invariant[T](xs: Invariant[T]) -> T: ... +def lift_invariant[T](xs: T) -> Invariant[T]: ... +``` + +A simple function that passes through its parameter type unchanged: + +`simple.py`: + +```py +from functions import invoke, identity reveal_type(invoke(identity, 1)) # revealed: Literal[1] +``` + +When the either the parameter or the return type is a generic alias referring to the typevar, we +should still be able to propagate the specializations through. This should work regardless of the +typevar's variance in the generic alias. + +TODO: This currently only works for the `lift` functions (TODO: and only currently for the covariant +case). For the `lift` functions, the parameter type is a bare typevar, resulting in us inferring a +type mapping of `A = int, B = Class[A]`. When specializing, we can substitute the mapping of `A` +into the mapping of `B`, giving the correct return type. + +For the `head` functions, the parameter type is a generic alias, resulting in us inferring a type +mapping of `A = Class[int], A = Class[B]`. At this point, the old solver is not able to unify the +two mappings for `A`, and we have no mapping for `B`. As a result, we infer `Unknown` for the return +type. + +As part of migrating to the new solver, we will generate a single constraint set combining all of +the facts that we learn while checking the arguments. And the constraint set implementation should +be able to unify the two assignments to `A`. -# TODO: this should be `Unknown | int` -reveal_type(invoke(head, [1, 2, 3])) # revealed: Unknown +`covariant.py`: + +```py +from functions import invoke, Covariant, head_covariant, lift_covariant + +# TODO: revealed: `int` +# revealed: Unknown +reveal_type(invoke(head_covariant, Covariant[int]())) +# revealed: Covariant[Literal[1]] +reveal_type(invoke(lift_covariant, 1)) +``` + +`contravariant.py`: + +```py +from functions import invoke, Contravariant, head_contravariant, lift_contravariant + +# TODO: revealed: `int` +# revealed: Unknown +reveal_type(invoke(head_contravariant, Contravariant[int]())) +# TODO: revealed: Contravariant[int] +# revealed: Unknown +reveal_type(invoke(lift_contravariant, 1)) +``` + +`invariant.py`: + +```py +from functions import invoke, Invariant, head_invariant, lift_invariant + +# TODO: revealed: `int` +# revealed: Unknown +reveal_type(invoke(head_invariant, Invariant[int]())) +# TODO: revealed: `Invariant[int]` +# revealed: Unknown +reveal_type(invoke(lift_invariant, 1)) ``` ## Protocols as TypeVar bounds diff --git a/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md b/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md index 32956cdfa8f6a9..5c37e118bfb8d4 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md @@ -416,19 +416,43 @@ def mutually_bound[T: Base, U](): ## Nested typevars -A typevar's constraint can _mention_ another typevar without _constraining_ it. In this example, `U` -must be specialized to `list[T]`, but it cannot affect what `T` is specialized to. +The specialization of one typevar can affect the specialization of another, even if it is not a +"top-level" type in the bounds. (That is, if it appears as inside the specialization of a generic +class.) ```py from typing import Never from ty_extensions import ConstraintSet, generic_context -def mentions[T, U](): - # (T@mentions ≤ int) ∧ (U@mentions = list[T@mentions]) - constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(list[T], U, list[T]) - # TODO: revealed: ty_extensions.Specialization[T@mentions = int, U@mentions = list[int]] - # revealed: ty_extensions.Specialization[T@mentions = int, U@mentions = Unknown] - reveal_type(generic_context(mentions).specialize_constrained(constraints)) +class Covariant[T]: + def get(self) -> T: + raise NotImplementedError + +class Contravariant[T]: + def receive(self, input: T): ... + +class Invariant[T]: + mutable_attribute: T + +def mentions_covariant[T, U](): + # (T@mentions_covariant ≤ int) ∧ (U@mentions_covariant ≤ Covariant[T@mentions_covariant]) + constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(Never, U, Covariant[T]) + # revealed: ty_extensions.Specialization[T@mentions_covariant = int, U@mentions_covariant = Covariant[int]] + reveal_type(generic_context(mentions_covariant).specialize_constrained(constraints)) + +def mentions_contravariant[T, U](): + # (T@mentions_contravariant ≤ int) ∧ (Contravariant[T@mentions_contravariant] ≤ U@mentions_contravariant) + constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(Contravariant[T], U, object) + # TODO: revealed: ty_extensions.Specialization[T@mentions_contravariant = int, U@mentions_contravariant = Contravariant[int]] + # revealed: ty_extensions.Specialization[T@mentions_contravariant = int, U@mentions_contravariant = Unknown] + reveal_type(generic_context(mentions_contravariant).specialize_constrained(constraints)) + +def mentions_invariant[T, U](): + # (T@mentions_invariant ≤ int) ∧ (U@mentions_invariant = Invariant[T@mentions_invariant]) + constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(Invariant[T], U, Invariant[T]) + # TODO: revealed: ty_extensions.Specialization[T@mentions_invariant = int, U@mentions_invariant = Invariant[int]] + # revealed: ty_extensions.Specialization[T@mentions_invariant = int, U@mentions_invariant = Unknown] + reveal_type(generic_context(mentions_invariant).specialize_constrained(constraints)) ``` If the constraint set contains mutually recursive bounds, specialization inference will not @@ -437,8 +461,8 @@ this case. ```py def divergent[T, U](): - # (T@divergent = list[U@divergent]) ∧ (U@divergent = list[T@divergent])) - constraints = ConstraintSet.range(list[U], T, list[U]) & ConstraintSet.range(list[T], U, list[T]) + # (T@divergent = Invariant[U@divergent]) ∧ (U@divergent = Invariant[T@divergent])) + constraints = ConstraintSet.range(Invariant[U], T, Invariant[U]) & ConstraintSet.range(Invariant[T], U, Invariant[T]) # revealed: None reveal_type(generic_context(divergent).specialize_constrained(constraints)) ``` diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md b/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md new file mode 100644 index 00000000000000..d3877275cd032c --- /dev/null +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md @@ -0,0 +1,407 @@ +# Constraint set quantification + +```toml +[environment] +python-version = "3.12" +``` + +We can _existentially quantify_ a constraint set over a type variable. The result is a copy of the +constraint set that only mentions the requested typevar. All constraints mentioning any other +typevars are removed. Importantly, they are removed "safely", with their constraints propagated +through to the remaining constraints as needed. + +## Keeping a single typevar + +If a constraint set only mentions a single typevar, and we keep that typevar when quantifying, the +result is unchanged. + +```py +from ty_extensions import ConstraintSet, static_assert + +class Base: ... +class Sub(Base): ... + +def keep_single[T](): + constraints = ConstraintSet.always() + quantified = ConstraintSet.always() + static_assert(constraints.retain_one(T) == quantified) + + constraints = ConstraintSet.never() + quantified = ConstraintSet.never() + static_assert(constraints.retain_one(T) == quantified) + + constraints = ConstraintSet.range(Sub, T, Base) + quantified = ConstraintSet.range(Sub, T, Base) + static_assert(constraints.retain_one(T) == quantified) +``` + +## Removing a single typevar + +If a constraint set only mentions a single typevar, and we remove that typevar when quantifying, the +result is usually "always". The only exception is if the original constraint set has no solution. In +that case, the result is also unsatisfiable. + +```py +from ty_extensions import ConstraintSet, static_assert + +class Base: ... +class Sub(Base): ... + +def remove_single[T](): + constraints = ConstraintSet.always() + quantified = ConstraintSet.always() + static_assert(constraints.exists(T) == quantified) + + constraints = ConstraintSet.never() + quantified = ConstraintSet.never() + static_assert(constraints.exists(T) == quantified) + + constraints = ConstraintSet.range(Sub, T, Base) + quantified = ConstraintSet.always() + static_assert(constraints.exists(T) == quantified) +``` + +This also holds when the constraint set contains multiple typevars. In the cases below, we are +keeping `U`, and the constraints on `T` do not ever affect what `U` can specialize to — `U` can +specialize to anything (unless the original constraint set is unsatisfiable). + +```py +from ty_extensions import ConstraintSet, static_assert + +class Base: ... +class Sub(Base): ... + +def remove_other[T, U](): + constraints = ConstraintSet.always() + quantified = ConstraintSet.always() + static_assert(constraints.retain_one(U) == quantified) + + constraints = ConstraintSet.never() + quantified = ConstraintSet.never() + static_assert(constraints.retain_one(U) == quantified) + + constraints = ConstraintSet.range(Sub, T, Base) + quantified = ConstraintSet.always() + static_assert(constraints.retain_one(U) == quantified) +``` + +## Transitivity + +When a constraint set mentions two typevars, and compares them directly, then we can use +transitivity to propagate the other constraints when quantifying. + +```py +from typing import Never +from ty_extensions import ConstraintSet, static_assert + +class Super: ... +class Base(Super): ... +class Sub(Base): ... + +def transitivity[T, U](): + # (Base ≤ T) ∧ (T ≤ U) → (Base ≤ U) + constraints = ConstraintSet.range(Base, T, object) & ConstraintSet.range(T, U, object) + quantified = ConstraintSet.range(Base, U, object) + static_assert(constraints.exists(T) == quantified) + + # (Base ≤ T ≤ Super) ∧ (T ≤ U) → (Base ≤ U) + constraints = ConstraintSet.range(Base, T, Super) & ConstraintSet.range(T, U, object) + quantified = ConstraintSet.range(Base, U, object) + static_assert(constraints.exists(T) == quantified) + + # (T ≤ Base) ∧ (U ≤ T) → (U ≤ Base) + constraints = ConstraintSet.range(Never, T, Base) & ConstraintSet.range(Never, U, T) + quantified = ConstraintSet.range(Never, U, Base) + static_assert(constraints.exists(T) == quantified) + + # (Sub ≤ T ≤ Base) ∧ (U ≤ T) → (U ≤ Base) + constraints = ConstraintSet.range(Sub, T, Base) & ConstraintSet.range(Never, U, T) + quantified = ConstraintSet.range(Never, U, Base) + static_assert(constraints.exists(T) == quantified) +``` + +## Covariant transitivity + +The same applies when one of the typevars is used covariantly in a bound of the other typevar. + +```py +from typing import Never +from ty_extensions import ConstraintSet, static_assert + +class Super: ... +class Base(Super): ... +class Sub(Base): ... + +class Covariant[T]: + def get(self) -> T: + raise NotImplementedError + +def covariant_transitivity[T, U](): + # (Base ≤ T) ∧ (Covariant[T] ≤ U) → (Covariant[Base] ≤ U) + constraints = ConstraintSet.range(Base, T, object) & ConstraintSet.range(Covariant[T], U, object) + quantified = ConstraintSet.range(Covariant[Base], U, object) + static_assert(constraints.exists(T) == quantified) + + # (Base ≤ T ≤ Super) ∧ (Covariant[T] ≤ U) → (Covariant[Base] ≤ U) + constraints = ConstraintSet.range(Base, T, Super) & ConstraintSet.range(Covariant[T], U, object) + quantified = ConstraintSet.range(Covariant[Base], U, object) + static_assert(constraints.exists(T) == quantified) + + # (T ≤ Base) ∧ (U ≤ Covariant[T]) → (U ≤ Covariant[Base]) + constraints = ConstraintSet.range(Never, T, Base) & ConstraintSet.range(Never, U, Covariant[T]) + quantified = ConstraintSet.range(Never, U, Covariant[Base]) + static_assert(constraints.exists(T) == quantified) + + # (Sub ≤ T ≤ Base) ∧ (U ≤ Covariant[T]) → (U ≤ Covariant[Base]) + constraints = ConstraintSet.range(Sub, T, Base) & ConstraintSet.range(Never, U, Covariant[T]) + quantified = ConstraintSet.range(Never, U, Covariant[Base]) + static_assert(constraints.exists(T) == quantified) +``` + +Same as above, but when propagating a third typevar instead of a concrete type. We make sure to test +with both variable orderings for the constraint that involves two typevars. + +```py +def covariant_typevar_transitivity[B, T, U](): + # (B ≤ T) ∧ (Covariant[T] ≤ U) → (Covariant[B] ≤ U) + constraints = ConstraintSet.range(B, T, object) & ConstraintSet.range(Covariant[T], U, object) + quantified = ConstraintSet.range(Covariant[B], U, object) + static_assert(constraints.exists(T) == quantified) + + # (T ≤ B) ∧ (U ≤ Covariant[T]) → (U ≤ Covariant[B]) + constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Never, U, Covariant[T]) + quantified = ConstraintSet.range(Never, U, Covariant[B]) + static_assert(constraints.exists(T) == quantified) + +def covariant_typevar_transitivity_reversed[T, B, U](): + # (B ≤ T) ∧ (Covariant[T] ≤ U) → (Covariant[B] ≤ U) + constraints = ConstraintSet.range(B, T, object) & ConstraintSet.range(Covariant[T], U, object) + quantified = ConstraintSet.range(Covariant[B], U, object) + static_assert(constraints.exists(T) == quantified) + + # (T ≤ B) ∧ (U ≤ Covariant[T]) → (U ≤ Covariant[B]) + constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Never, U, Covariant[T]) + quantified = ConstraintSet.range(Never, U, Covariant[B]) + static_assert(constraints.exists(T) == quantified) +``` + +## Contravariant transitivity + +Similar rules apply, but in reverse, when one of the typevars is used contravariantly in a bound of +the other typevar. + +```py +from typing import Never +from ty_extensions import ConstraintSet, static_assert + +class Super: ... +class Base(Super): ... +class Sub(Base): ... + +class Contravariant[T]: + def receive(self, input: T): ... + +def contravariant_transitivity[T, U](): + # (Base ≤ T) ∧ (U ≤ Contravariant[T]) → (U ≤ Contravariant[Base]) + constraints = ConstraintSet.range(Base, T, object) & ConstraintSet.range(Never, U, Contravariant[T]) + quantified = ConstraintSet.range(Never, U, Contravariant[Base]) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (Base ≤ T ≤ Super) ∧ (U ≤ Contravariant[T]) → (U ≤ Contravariant[Base]) + constraints = ConstraintSet.range(Base, T, Super) & ConstraintSet.range(Never, U, Contravariant[T]) + quantified = ConstraintSet.range(Never, U, Contravariant[Base]) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T ≤ Base) ∧ (Contravariant[T] ≤ U) → (Contravariant[Base] ≤ U) + constraints = ConstraintSet.range(Never, T, Base) & ConstraintSet.range(Contravariant[T], U, object) + quantified = ConstraintSet.range(Contravariant[Base], U, object) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (Sub ≤ T ≤ Base) ∧ (Contravariant[T] ≤ U) → (Contravariant[Base] ≤ U) + constraints = ConstraintSet.range(Sub, T, Base) & ConstraintSet.range(Contravariant[T], U, object) + quantified = ConstraintSet.range(Contravariant[Base], U, object) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) +``` + +Same as above, but when propagating a third typevar instead of a concrete type. We make sure to test +with both variable orderings for the constraint that involves two typevars. + +```py +def contravariant_typevar_transitivity[B, T, U](): + # (B ≤ T) ∧ (U ≤ Contravariant[T]) → (U ≤ Contravariant[B]) + constraints = ConstraintSet.range(B, T, object) & ConstraintSet.range(Never, U, Contravariant[T]) + quantified = ConstraintSet.range(Never, U, Contravariant[B]) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T ≤ B) ∧ (Contravariant[T] ≤ U) → (Contravariant[B] ≤ U) + constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Contravariant[T], U, object) + quantified = ConstraintSet.range(Contravariant[B], U, object) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + +def contravariant_typevar_transitivity_reversed[T, B, U](): + # (B ≤ T) ∧ (U ≤ Contravariant[T]) → (U ≤ Contravariant[B]) + constraints = ConstraintSet.range(B, T, object) & ConstraintSet.range(Never, U, Contravariant[T]) + quantified = ConstraintSet.range(Never, U, Contravariant[B]) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T ≤ B) ∧ (Contravariant[T] ≤ U) → (Contravariant[B] ≤ U) + constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Contravariant[T], U, object) + quantified = ConstraintSet.range(Contravariant[B], U, object) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) +``` + +## Invariant transitivity involving equality constraints + +Invariant uses of a typevar are more subtle. The simplest case is when there is an _equality_ +constraint on the invariant typevar. In that case, we know precisely which specialization is +required. + +```py +from typing import Never +from ty_extensions import ConstraintSet, static_assert + +class Base: ... + +class Invariant[T]: + mutable_attribute: T + +def invariant_equality_transitivity[T, U](): + # (T = Base) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[Base]) + constraints = ConstraintSet.range(Base, T, Base) & ConstraintSet.range(Never, U, Invariant[T]) + quantified = ConstraintSet.range(Never, U, Invariant[Base]) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T = Base) ∧ (Invariant[T] ≤ U) → (Invariant[Base] ≤ U) + constraints = ConstraintSet.range(Base, T, Base) & ConstraintSet.range(Invariant[T], U, object) + quantified = ConstraintSet.range(Invariant[Base], U, object) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) +``` + +Same as above, but when propagating a third typevar instead of a concrete type. We make sure to test +with both variable orderings for the constraint that involves two typevars. + +```py +def invariant_equality_typevar_transitivity[B, T, U](): + # (T = B) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[B]) + constraints = ConstraintSet.range(B, T, B) & ConstraintSet.range(Never, U, Invariant[T]) + quantified = ConstraintSet.range(Never, U, Invariant[B]) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T = B) ∧ (Invariant[T] ≤ U) → (Invariant[B] ≤ U) + constraints = ConstraintSet.range(B, T, B) & ConstraintSet.range(Invariant[T], U, object) + quantified = ConstraintSet.range(Invariant[B], U, object) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + +def invariant_equality_typevar_transitivity_reverse[T, B, U](): + # (T = B) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[B]) + constraints = ConstraintSet.range(B, T, B) & ConstraintSet.range(Never, U, Invariant[T]) + quantified = ConstraintSet.range(Never, U, Invariant[B]) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T = B) ∧ (Invariant[T] ≤ U) → (Invariant[B] ≤ U) + constraints = ConstraintSet.range(B, T, B) & ConstraintSet.range(Invariant[T], U, object) + quantified = ConstraintSet.range(Invariant[B], U, object) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) +``` + +## Invariant transitivity involving range constraints + +When there is a _range_ constraint on the invariant typevar, we still have to retain information +about which range of types the quantified-away typevar can specialize to, since this affects which +types the remaining typevar can specialize to, and invariant typevars are not monotonic like +covariant and contravariant typevars. + +```py +from typing import Never +from ty_extensions import ConstraintSet, static_assert + +class Base: ... +class Sub(Base): ... + +class Invariant[T]: + mutable_attribute: T + +def invariant_range_transitivity[T, U](): + # (Sub ≤ T ≤ Base) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[Exists[Sub, Base]]) + constraints = ConstraintSet.range(Sub, T, Base) & ConstraintSet.range(Never, U, Invariant[T]) + # TODO: The existential that we need doesn't exist yet. + quantified = ConstraintSet.never() + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (Sub ≤ T ≤ Base) ∧ (Invariant[T] ≤ U) → (Invariant[Exists[Sub, Base]] ≤ U) + constraints = ConstraintSet.range(Sub, T, Base) & ConstraintSet.range(Invariant[T], U, object) + # TODO: The existential that we need doesn't exist yet. + quantified = ConstraintSet.never() + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) +``` + +Same as above, but when propagating a third typevar instead of a concrete type. We make sure to test +with both variable orderings for the constraint that involves two typevars. + +```py +def invariant_range_typevar_transitivity[B, T, U](): + # (T ≤ B) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[Exists[Never, B]]) + constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Never, U, Invariant[T]) + # TODO: The existential that we need doesn't exist yet. + quantified = ConstraintSet.never() + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T ≤ B) ∧ (Invariant[T] ≤ U) → (Invariant[Exists[Never, B]] ≤ U) + constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Invariant[T], U, object) + # TODO: The existential that we need doesn't exist yet. + quantified = ConstraintSet.never() + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + +def invariant_range_typevar_transitivity_reverse[T, B, U](): + # (T ≤ B) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[Exists[Never, B]]) + constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Never, U, Invariant[T]) + # TODO: The existential that we need doesn't exist yet. + quantified = ConstraintSet.never() + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T ≤ B) ∧ (Invariant[T] ≤ U) → (Invariant[Exists[Never, B]] ≤ U) + constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Invariant[T], U, object) + # TODO: The existential that we need doesn't exist yet. + quantified = ConstraintSet.never() + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) +``` diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index e4a0228cd41b54..85c1bb3091308d 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -5026,6 +5026,12 @@ impl<'db> Type<'db> { )) .into() } + Type::KnownInstance(KnownInstanceType::ConstraintSet(tracked)) if name == "exists" => { + Place::bound(Type::KnownBoundMethod( + KnownBoundMethodType::ConstraintSetExists(tracked), + )) + .into() + } Type::KnownInstance(KnownInstanceType::ConstraintSet(tracked)) if name == "implies_subtype_of" => { @@ -5034,6 +5040,14 @@ impl<'db> Type<'db> { )) .into() } + Type::KnownInstance(KnownInstanceType::ConstraintSet(tracked)) + if name == "retain_one" => + { + Place::bound(Type::KnownBoundMethod( + KnownBoundMethodType::ConstraintSetRetainOne(tracked), + )) + .into() + } Type::KnownInstance(KnownInstanceType::ConstraintSet(tracked)) if name == "satisfies" => { @@ -8035,7 +8049,9 @@ impl<'db> Type<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_) @@ -8254,7 +8270,9 @@ impl<'db> Type<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_), @@ -12674,7 +12692,9 @@ pub enum KnownBoundMethodType<'db> { ConstraintSetRange, ConstraintSetAlways, ConstraintSetNever, + ConstraintSetExists(TrackedConstraintSet<'db>), ConstraintSetImpliesSubtypeOf(TrackedConstraintSet<'db>), + ConstraintSetRetainOne(TrackedConstraintSet<'db>), ConstraintSetSatisfies(TrackedConstraintSet<'db>), ConstraintSetSatisfiedByAllTypeVars(TrackedConstraintSet<'db>), @@ -12706,7 +12726,9 @@ pub(super) fn walk_method_wrapper_type<'db, V: visitor::TypeVisitor<'db> + ?Size KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_) => {} @@ -12773,10 +12795,18 @@ impl<'db> KnownBoundMethodType<'db> { KnownBoundMethodType::ConstraintSetNever, KnownBoundMethodType::ConstraintSetNever, ) + | ( + KnownBoundMethodType::ConstraintSetExists(_), + KnownBoundMethodType::ConstraintSetExists(_), + ) | ( KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_), KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_), ) + | ( + KnownBoundMethodType::ConstraintSetRetainOne(_), + KnownBoundMethodType::ConstraintSetRetainOne(_), + ) | ( KnownBoundMethodType::ConstraintSetSatisfies(_), KnownBoundMethodType::ConstraintSetSatisfies(_), @@ -12799,7 +12829,9 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_), @@ -12811,7 +12843,9 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_), @@ -12864,9 +12898,17 @@ impl<'db> KnownBoundMethodType<'db> { ) => ConstraintSet::from(true), ( + KnownBoundMethodType::ConstraintSetExists(left_constraints), + KnownBoundMethodType::ConstraintSetExists(right_constraints), + ) + | ( KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(left_constraints), KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(right_constraints), ) + | ( + KnownBoundMethodType::ConstraintSetRetainOne(left_constraints), + KnownBoundMethodType::ConstraintSetRetainOne(right_constraints), + ) | ( KnownBoundMethodType::ConstraintSetSatisfies(left_constraints), KnownBoundMethodType::ConstraintSetSatisfies(right_constraints), @@ -12892,7 +12934,9 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_), @@ -12904,7 +12948,9 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_), @@ -12930,7 +12976,9 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_) => self, @@ -12968,7 +13016,9 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_) => Some(self), @@ -12986,7 +13036,9 @@ impl<'db> KnownBoundMethodType<'db> { KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_) => { @@ -13132,6 +13184,18 @@ impl<'db> KnownBoundMethodType<'db> { ))) } + KnownBoundMethodType::ConstraintSetExists(_) => { + Either::Right(std::iter::once(Signature::new( + Parameters::new( + db, + [Parameter::variadic(Name::new_static("typevars")) + .type_form() + .with_annotated_type(Type::any())], + ), + Some(KnownClass::ConstraintSet.to_instance(db)), + ))) + } + KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) => { Either::Right(std::iter::once(Signature::new( Parameters::new( @@ -13149,6 +13213,20 @@ impl<'db> KnownBoundMethodType<'db> { ))) } + KnownBoundMethodType::ConstraintSetRetainOne(_) => { + Either::Right(std::iter::once(Signature::new( + Parameters::new( + db, + [ + Parameter::positional_only(Some(Name::new_static("typevar"))) + .type_form() + .with_annotated_type(Type::any()), + ], + ), + Some(KnownClass::ConstraintSet.to_instance(db)), + ))) + } + KnownBoundMethodType::ConstraintSetSatisfies(_) => { Either::Right(std::iter::once(Signature::new( Parameters::new( diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index cd50607c146722..c17a2c0c2c760f 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -1255,6 +1255,25 @@ impl<'db> Bindings<'db> { )); } + Type::KnownBoundMethod(KnownBoundMethodType::ConstraintSetExists(tracked)) => { + let typevars: Option> = overload + .arguments_for_parameter(argument_types, 0) + .map(|(_, ty)| { + ty.as_typevar() + .map(|bound_typevar| bound_typevar.identity(db)) + }) + .collect(); + let Some(typevars) = typevars else { + continue; + }; + + let result = tracked.constraints(db).exists(db, typevars); + let tracked = TrackedConstraintSet::new(db, result); + overload.set_return_type(Type::KnownInstance( + KnownInstanceType::ConstraintSet(tracked), + )); + } + Type::KnownBoundMethod( KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(tracked), ) => { @@ -1274,6 +1293,20 @@ impl<'db> Bindings<'db> { )); } + Type::KnownBoundMethod(KnownBoundMethodType::ConstraintSetRetainOne( + tracked, + )) => { + let [Some(Type::TypeVar(typevar))] = overload.parameter_types() else { + continue; + }; + + let result = tracked.constraints(db).retain_one(db, typevar.identity(db)); + let tracked = TrackedConstraintSet::new(db, result); + overload.set_return_type(Type::KnownInstance( + KnownInstanceType::ConstraintSet(tracked), + )); + } + Type::KnownBoundMethod(KnownBoundMethodType::ConstraintSetSatisfies( tracked, )) => { diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index 77b96bd74bbc14..b7474145257c41 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -75,13 +75,16 @@ use itertools::Itertools; use rustc_hash::{FxHashMap, FxHashSet}; use salsa::plumbing::AsId; -use crate::types::generics::{GenericContext, InferableTypeVars, Specialization}; +use crate::types::generics::{ + GenericContext, InferableTypeVars, PartialSpecialization, Specialization, +}; +use crate::types::variance::VarianceInferable; use crate::types::visitor::{ TypeCollector, TypeVisitor, any_over_type, walk_type_with_recursion_guard, }; use crate::types::{ - BoundTypeVarIdentity, BoundTypeVarInstance, IntersectionType, Type, TypeVarBoundOrConstraints, - UnionType, walk_bound_type_var_type, + BoundTypeVarIdentity, BoundTypeVarInstance, IntersectionType, Type, TypeContext, TypeMapping, + TypeVarBoundOrConstraints, TypeVarVariance, UnionType, walk_bound_type_var_type, }; use crate::{Db, FxOrderMap}; @@ -432,6 +435,29 @@ impl<'db> ConstraintSet<'db> { } } + /// Returns a new constraint set that is the _existential abstraction_ of `self` for a set of + /// typevars. The result will return true whenever `self` returns true for _any_ assignment of + /// those typevars. The result will not contain any constraints that mention those typevars. + pub(crate) fn exists( + self, + db: &'db dyn Db, + bound_typevars: impl IntoIterator>, + ) -> Self { + let node = self.node.exists(db, bound_typevars); + Self { node } + } + + /// Quantifies over this constraint set so that it only contains constraints that mention the + /// given typevar. All other typevars are quantified away. + pub(crate) fn retain_one( + self, + db: &'db dyn Db, + bound_typevar: BoundTypeVarIdentity<'db>, + ) -> Self { + let node = self.node.retain_one(db, bound_typevar); + Self { node } + } + /// Reduces the set of inferable typevars for this constraint set. You provide an iterator of /// the typevars that were inferable when this constraint set was created, and which should be /// abstracted away. Those typevars will be removed from the constraint set, and the constraint @@ -3027,30 +3053,59 @@ impl<'db> SequentMap<'db> { left_constraint: ConstrainedTypeVar<'db>, right_constraint: ConstrainedTypeVar<'db>, ) { - // We've structured our constraints so that a typevar's upper/lower bound can only - // be another typevar if the bound is "later" in our arbitrary ordering. That means - // we only have to check this pair of constraints in one direction — though we do - // have to figure out which of the two typevars is constrained, and which one is - // the upper/lower bound. + // Add sequents when the two typevars are mutually constrained directly — that is, when the + // lower/upper bound _is_ the typevar, not _contains_ the typevar. We've structured our + // constraints so that a typevar's upper/lower bound can only be another typevar if the + // bound is "later" in our arbitrary ordering. That means we only have to check this pair + // of constraints in one direction — though we do have to figure out which of the two + // typevars is constrained, and which one is the upper/lower bound. let left_typevar = left_constraint.typevar(db); let right_typevar = right_constraint.typevar(db); - let (bound_typevar, bound_constraint, constrained_typevar, constrained_constraint) = - if left_typevar.can_be_bound_for(db, right_typevar) { - ( - left_typevar, - left_constraint, - right_typevar, - right_constraint, - ) - } else { - ( - right_typevar, - right_constraint, - left_typevar, - left_constraint, - ) - }; + if left_typevar.can_be_bound_for(db, right_typevar) { + self.add_direct_mutual_sequents_for_different_typevars( + db, + left_constraint, + left_typevar, + right_constraint, + right_typevar, + ); + } else { + self.add_direct_mutual_sequents_for_different_typevars( + db, + right_constraint, + right_typevar, + left_constraint, + left_typevar, + ); + } + + // Add sequents when one of the typevars is mentioned deeply inside the bounds of the + // other. Because the "bound" typevar appears deeply inside of the "constrained" typevar, + // our constraint ordering doesn't apply, and we have to try each direction. + self.add_deep_mutual_sequents( + db, + left_constraint, + left_typevar, + right_constraint, + right_typevar, + ); + self.add_deep_mutual_sequents( + db, + right_constraint, + right_typevar, + left_constraint, + left_typevar, + ); + } + fn add_direct_mutual_sequents_for_different_typevars( + &mut self, + db: &'db dyn Db, + bound_constraint: ConstrainedTypeVar<'db>, + bound_typevar: BoundTypeVarInstance<'db>, + constrained_constraint: ConstrainedTypeVar<'db>, + constrained_typevar: BoundTypeVarInstance<'db>, + ) { // We then look for cases where the "constrained" typevar's upper and/or lower bound // matches the "bound" typevar. If so, we're going to add an implication sequent that // replaces the upper/lower bound that matched with the bound constraint's corresponding @@ -3104,10 +3159,157 @@ impl<'db> SequentMap<'db> { let post_constraint = ConstrainedTypeVar::new(db, constrained_typevar, new_lower, new_upper); - self.add_pair_implication(db, left_constraint, right_constraint, post_constraint); + self.add_pair_implication( + db, + bound_constraint, + constrained_constraint, + post_constraint, + ); self.enqueue_constraint(post_constraint); } + fn add_deep_mutual_sequents( + &mut self, + db: &'db dyn Db, + bound_constraint: ConstrainedTypeVar<'db>, + bound_typevar: BoundTypeVarInstance<'db>, + constrained_constraint: ConstrainedTypeVar<'db>, + constrained_typevar: BoundTypeVarInstance<'db>, + ) { + let bound_lower = bound_constraint.lower(db); + let bound_upper = bound_constraint.upper(db); + let constrained_lower = constrained_constraint.lower(db); + let constrained_upper = constrained_constraint.upper(db); + + let mut add_constraint = |post_constraint| { + self.add_pair_implication( + db, + bound_constraint, + constrained_constraint, + post_constraint, + ); + self.enqueue_constraint(post_constraint); + }; + + match constrained_lower.variance_of(db, bound_typevar) { + // B does not appear in CU, or if it does, it appears bivariantly. The constraints of B + // do not affect the valid specializations of C. + TypeVarVariance::Bivariant => {} + + // (Covariant[B] ≤ C ≤ CU) ∧ (BL ≤ B ≤ BU) → (Covariant[BL] ≤ C ≤ CU) + TypeVarVariance::Covariant => { + // Only substitute to create a new sequent if the substitution is interesting, and + // doesn't recursively contain the typevar we are substituting for. + if !bound_lower.is_never() + && !bound_lower.is_object() + && bound_lower.variance_of(db, bound_typevar) == TypeVarVariance::Bivariant + { + let partial = PartialSpecialization::Single { + bound_typevar, + ty: bound_lower, + }; + let new_lower = constrained_lower.apply_type_mapping( + db, + &TypeMapping::PartialSpecialization(partial), + TypeContext::default(), + ); + add_constraint(ConstrainedTypeVar::new( + db, + constrained_typevar, + new_lower, + constrained_upper, + )); + } + } + + // TODO + TypeVarVariance::Contravariant | TypeVarVariance::Invariant => {} + } + + // (Covariant[BU] ≤ C ≤ CU) ∧ (BL ≤ B ≤ BU) → (Covariant[B] ≤ C ≤ CU) + if let Type::TypeVar(bound_upper_typevar) = bound_upper + && !bound_upper_typevar.is_same_typevar_as(db, constrained_typevar) + { + if constrained_lower.variance_of(db, bound_upper_typevar) == TypeVarVariance::Covariant + { + let partial = PartialSpecialization::Single { + bound_typevar: bound_upper_typevar, + ty: Type::TypeVar(bound_typevar), + }; + let new_lower = constrained_lower.apply_type_mapping( + db, + &TypeMapping::PartialSpecialization(partial), + TypeContext::default(), + ); + add_constraint(ConstrainedTypeVar::new( + db, + constrained_typevar, + new_lower, + constrained_upper, + )); + } + } + + match constrained_upper.variance_of(db, bound_typevar) { + // B does not appear in CU, or if it does, it appears bivariantly. The constraints of B + // do not affect the valid specializations of C. + TypeVarVariance::Bivariant => {} + + // (CL ≤ C ≤ Covariant[B]) ∧ (BL ≤ B ≤ BU) → (CL ≤ C ≤ Covariant[BU]) + TypeVarVariance::Covariant => { + // Only substitute to create a new sequent if the substitution is interesting, and + // doesn't recursively contain the typevar we are substituting for. + if !bound_upper.is_never() + && !bound_upper.is_object() + && bound_upper.variance_of(db, bound_typevar) == TypeVarVariance::Bivariant + { + let partial = PartialSpecialization::Single { + bound_typevar, + ty: bound_upper, + }; + let new_upper = constrained_upper.apply_type_mapping( + db, + &TypeMapping::PartialSpecialization(partial), + TypeContext::default(), + ); + add_constraint(ConstrainedTypeVar::new( + db, + constrained_typevar, + constrained_lower, + new_upper, + )); + } + } + + // TODO + TypeVarVariance::Contravariant | TypeVarVariance::Invariant => {} + } + + // (CL ≤ C ≤ Covariant[BL]) ∧ (BL ≤ B ≤ BU) → (CL ≤ C ≤ Covariant[B]) + if let Type::TypeVar(bound_lower_typevar) = bound_lower + && !bound_lower_typevar.is_same_typevar_as(db, constrained_typevar) + { + if constrained_upper.variance_of(db, bound_lower_typevar) == TypeVarVariance::Covariant + { + let partial = PartialSpecialization::Single { + bound_typevar: bound_lower_typevar, + ty: Type::TypeVar(bound_typevar), + }; + let new_upper = constrained_upper.apply_type_mapping( + db, + &TypeMapping::PartialSpecialization(partial), + TypeContext::default(), + ); + add_constraint(ConstrainedTypeVar::new( + db, + constrained_typevar, + constrained_lower, + new_upper, + )); + } + } + } + fn add_mutual_sequents_for_same_typevars( &mut self, db: &'db dyn Db, diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index 573756d1b215a2..e8e9d4662520ce 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -838,9 +838,15 @@ impl<'db> FmtDetailed<'db> for DisplayRepresentation<'db> { KnownBoundMethodType::ConstraintSetNever => { return f.write_str("bound method `ConstraintSet.never`"); } + KnownBoundMethodType::ConstraintSetExists(_) => { + return f.write_str("bound method `ConstraintSet.exists`"); + } KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) => { return f.write_str("bound method `ConstraintSet.implies_subtype_of`"); } + KnownBoundMethodType::ConstraintSetRetainOne(_) => { + return f.write_str("bound method `ConstraintSet.retain_one`"); + } KnownBoundMethodType::ConstraintSetSatisfies(_) => { return f.write_str("bound method `ConstraintSet.satisfies`"); } diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index ca6a700bd240f0..3b728ee4578c04 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -568,7 +568,7 @@ impl<'db> GenericContext<'db> { loop { let mut any_changed = false; for i in 0..len { - let partial = PartialSpecialization { + let partial = PartialSpecialization::FromGenericContext { generic_context: self, types: &types, // Don't recursively substitute type[i] in itself. Ideally, we could instead @@ -646,7 +646,7 @@ impl<'db> GenericContext<'db> { // Typevars are only allowed to refer to _earlier_ typevars in their defaults. (This is // statically enforced for PEP-695 contexts, and is explicitly called out as a // requirement for legacy contexts.) - let partial = PartialSpecialization { + let partial = PartialSpecialization::FromGenericContext { generic_context: self, types: &expanded[0..idx], skip: None, @@ -1452,12 +1452,18 @@ impl<'db> Specialization<'db> { /// You will usually use [`Specialization`] instead of this type. This type is used when we need to /// substitute types for type variables before we have fully constructed a [`Specialization`]. #[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize)] -pub struct PartialSpecialization<'a, 'db> { - generic_context: GenericContext<'db>, - types: &'a [Type<'db>], - /// An optional typevar to _not_ substitute when applying the specialization. We use this to - /// avoid recursively substituting a type inside of itself. - skip: Option, +pub enum PartialSpecialization<'a, 'db> { + FromGenericContext { + generic_context: GenericContext<'db>, + types: &'a [Type<'db>], + /// An optional typevar to _not_ substitute when applying the specialization. We use this to + /// avoid recursively substituting a type inside of itself. + skip: Option, + }, + Single { + bound_typevar: BoundTypeVarInstance<'db>, + ty: Type<'db>, + }, } impl<'db> PartialSpecialization<'_, 'db> { @@ -1466,16 +1472,30 @@ impl<'db> PartialSpecialization<'_, 'db> { pub(crate) fn get( &self, db: &'db dyn Db, - bound_typevar: BoundTypeVarInstance<'db>, + needle_bound_typevar: BoundTypeVarInstance<'db>, ) -> Option> { - let index = self - .generic_context - .variables_inner(db) - .get_index_of(&bound_typevar.identity(db))?; - if self.skip.is_some_and(|skip| skip == index) { - return Some(Type::Never); + match self { + PartialSpecialization::FromGenericContext { + generic_context, + types, + skip, + } => { + let index = generic_context + .variables_inner(db) + .get_index_of(&needle_bound_typevar.identity(db))?; + if skip.is_some_and(|skip| skip == index) { + return Some(Type::Never); + } + types.get(index).copied() + } + PartialSpecialization::Single { bound_typevar, ty } => { + if bound_typevar.is_same_typevar_as(db, needle_bound_typevar) { + Some(*ty) + } else { + None + } + } } - self.types.get(index).copied() } } @@ -1540,7 +1560,7 @@ impl<'db> SpecializationBuilder<'db> { bound_typevar: BoundTypeVarInstance<'db>, ty: Type<'db>, variance: TypeVarVariance, - mut f: impl FnMut(TypeVarAssignment<'db>) -> Option>, + f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option>, ) { let identity = bound_typevar.identity(self.db); let Some(ty) = f((identity, variance, ty)) else { @@ -1582,7 +1602,7 @@ impl<'db> SpecializationBuilder<'db> { &mut self, formal: Type<'db>, constraints: ConstraintSet<'db>, - mut f: impl FnMut(TypeVarAssignment<'db>) -> Option>, + f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option>, ) { #[derive(Default)] struct Bounds<'db> { @@ -1626,16 +1646,20 @@ impl<'db> SpecializationBuilder<'db> { } for (bound_typevar, bounds) in mappings.drain() { - let variance = formal.variance_of(self.db, bound_typevar); - let upper = IntersectionType::from_elements(self.db, bounds.upper); - if !upper.is_object() { - self.add_type_mapping(bound_typevar, upper, variance, &mut f); + let try_upper = || { + let upper = IntersectionType::from_elements(self.db, bounds.upper); + (!upper.is_object()).then_some(upper) + }; + let try_lower = || { + let lower = UnionType::from_elements(self.db, bounds.lower); + (!lower.is_never()).then_some(lower) + }; + let Some(mapped_type) = try_upper().or_else(try_lower) else { continue; - } - let lower = UnionType::from_elements(self.db, bounds.lower); - if !lower.is_never() { - self.add_type_mapping(bound_typevar, lower, variance, &mut f); - } + }; + + let variance = formal.variance_of(self.db, bound_typevar); + self.add_type_mapping(bound_typevar, mapped_type, variance, f); } } } @@ -1667,7 +1691,7 @@ impl<'db> SpecializationBuilder<'db> { formal: Type<'db>, actual: Type<'db>, polarity: TypeVarVariance, - mut f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option>, + f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option>, ) -> Result<(), SpecializationError<'db>> { // TODO: Eventually, the builder will maintain a constraint set, instead of a hash-map of // type mappings, to represent the specialization that we are building up. At that point, @@ -1779,7 +1803,7 @@ impl<'db> SpecializationBuilder<'db> { let mut first_error = None; let mut found_matching_element = false; for formal_element in union_formal.elements(self.db) { - let result = self.infer_map_impl(*formal_element, actual, polarity, &mut f); + let result = self.infer_map_impl(*formal_element, actual, polarity, f); if let Err(err) = result { first_error.get_or_insert(err); } else { @@ -1882,7 +1906,7 @@ impl<'db> SpecializationBuilder<'db> { formal_tuple.all_elements().zip(actual_tuple.all_elements()) { let variance = TypeVarVariance::Covariant.compose(polarity); - self.infer_map_impl(*formal_element, *actual_element, variance, &mut f)?; + self.infer_map_impl(*formal_element, *actual_element, variance, f)?; } return Ok(()); } @@ -1925,7 +1949,7 @@ impl<'db> SpecializationBuilder<'db> { base_specialization ) { let variance = typevar.variance_with_polarity(self.db, polarity); - self.infer_map_impl(*formal_ty, *base_ty, variance, &mut f)?; + self.infer_map_impl(*formal_ty, *base_ty, variance, f)?; } return Ok(()); } @@ -1949,7 +1973,7 @@ impl<'db> SpecializationBuilder<'db> { formal_callable, self.inferable, ); - self.add_type_mappings_from_constraint_set(formal, when, &mut f); + self.add_type_mappings_from_constraint_set(formal, when, f); } else { for actual_signature in &actual_callable.signatures(self.db).overloads { let when = actual_signature @@ -1958,7 +1982,7 @@ impl<'db> SpecializationBuilder<'db> { formal_callable, self.inferable, ); - self.add_type_mappings_from_constraint_set(formal, when, &mut f); + self.add_type_mappings_from_constraint_set(formal, when, f); } } } diff --git a/crates/ty_vendored/ty_extensions/ty_extensions.pyi b/crates/ty_vendored/ty_extensions/ty_extensions.pyi index 347b6b4b3491d4..1ab24ca55592e7 100644 --- a/crates/ty_vendored/ty_extensions/ty_extensions.pyi +++ b/crates/ty_vendored/ty_extensions/ty_extensions.pyi @@ -59,6 +59,14 @@ class ConstraintSet: def never() -> Self: """Returns a constraint set that is never satisfied""" + def exists(self, *typevars: Any) -> Self: + """ + Returns a new constraint set that is the _existential abstraction_ of + `self` for a set of typevars. The result will return true whenever + `self` returns true for _any_ assignment of those typevars. The result + will not contain any constraints that mention those typevars. + """ + def implies_subtype_of(self, ty: Any, of: Any) -> Self: """ Returns a constraint set that is satisfied when `ty` is a `subtype`_ of @@ -67,6 +75,12 @@ class ConstraintSet: .. _subtype: https://typing.python.org/en/latest/spec/concepts.html#subtype-supertype-and-type-equivalence """ + def retain_one(self, typevar: Any) -> Self: + """ + Quantifies over this constraint set so that it only contains constraints + that mention `typevar`. All other typevars are quantified away. + """ + def satisfies(self, other: Self) -> Self: """ Returns whether this constraint set satisfies another — that is, whether