From 6d4681e65db4b5908ddaa4c3c8f8aa0269f67315 Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Wed, 17 Dec 2025 16:32:13 -0500 Subject: [PATCH 01/16] add `retain_one` extension method --- crates/ty_python_semantic/src/types.rs | 41 +++++++++++++++++++ .../ty_python_semantic/src/types/call/bind.rs | 14 +++++++ .../src/types/constraints.rs | 11 +++++ .../ty_python_semantic/src/types/display.rs | 3 ++ .../ty_extensions/ty_extensions.pyi | 6 +++ 5 files changed, 75 insertions(+) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index e4a0228cd41b5..6a1245e87ebc2 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -5034,6 +5034,14 @@ impl<'db> Type<'db> { )) .into() } + Type::KnownInstance(KnownInstanceType::ConstraintSet(tracked)) + if name == "retain_one" => + { + Place::bound(Type::KnownBoundMethod( + KnownBoundMethodType::ConstraintSetRetainOne(tracked), + )) + .into() + } Type::KnownInstance(KnownInstanceType::ConstraintSet(tracked)) if name == "satisfies" => { @@ -8036,6 +8044,7 @@ impl<'db> Type<'db> { | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_) @@ -8255,6 +8264,7 @@ impl<'db> Type<'db> { | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_), @@ -12675,6 +12685,7 @@ pub enum KnownBoundMethodType<'db> { ConstraintSetAlways, ConstraintSetNever, ConstraintSetImpliesSubtypeOf(TrackedConstraintSet<'db>), + ConstraintSetRetainOne(TrackedConstraintSet<'db>), ConstraintSetSatisfies(TrackedConstraintSet<'db>), ConstraintSetSatisfiedByAllTypeVars(TrackedConstraintSet<'db>), @@ -12707,6 +12718,7 @@ pub(super) fn walk_method_wrapper_type<'db, V: visitor::TypeVisitor<'db> + ?Size | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_) => {} @@ -12777,6 +12789,10 @@ impl<'db> KnownBoundMethodType<'db> { KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_), KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_), ) + | ( + KnownBoundMethodType::ConstraintSetRetainOne(_), + KnownBoundMethodType::ConstraintSetRetainOne(_), + ) | ( KnownBoundMethodType::ConstraintSetSatisfies(_), KnownBoundMethodType::ConstraintSetSatisfies(_), @@ -12800,6 +12816,7 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_), @@ -12812,6 +12829,7 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_), @@ -12867,6 +12885,10 @@ impl<'db> KnownBoundMethodType<'db> { KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(left_constraints), KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(right_constraints), ) + | ( + KnownBoundMethodType::ConstraintSetRetainOne(left_constraints), + KnownBoundMethodType::ConstraintSetRetainOne(right_constraints), + ) | ( KnownBoundMethodType::ConstraintSetSatisfies(left_constraints), KnownBoundMethodType::ConstraintSetSatisfies(right_constraints), @@ -12893,6 +12915,7 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_), @@ -12905,6 +12928,7 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_), @@ -12931,6 +12955,7 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_) => self, @@ -12969,6 +12994,7 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_) => Some(self), @@ -12987,6 +13013,7 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) | KnownBoundMethodType::GenericContextSpecializeConstrained(_) => { @@ -13149,6 +13176,20 @@ impl<'db> KnownBoundMethodType<'db> { ))) } + KnownBoundMethodType::ConstraintSetRetainOne(_) => { + Either::Right(std::iter::once(Signature::new( + Parameters::new( + db, + [ + Parameter::positional_only(Some(Name::new_static("typevar"))) + .type_form() + .with_annotated_type(Type::any()), + ], + ), + Some(KnownClass::ConstraintSet.to_instance(db)), + ))) + } + KnownBoundMethodType::ConstraintSetSatisfies(_) => { Either::Right(std::iter::once(Signature::new( Parameters::new( diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index cd50607c14672..c05a5db9d39bd 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -1274,6 +1274,20 @@ impl<'db> Bindings<'db> { )); } + Type::KnownBoundMethod(KnownBoundMethodType::ConstraintSetRetainOne( + tracked, + )) => { + let [Some(Type::TypeVar(typevar))] = overload.parameter_types() else { + continue; + }; + + let result = tracked.constraints(db).retain_one(db, typevar.identity(db)); + let tracked = TrackedConstraintSet::new(db, result); + overload.set_return_type(Type::KnownInstance( + KnownInstanceType::ConstraintSet(tracked), + )); + } + Type::KnownBoundMethod(KnownBoundMethodType::ConstraintSetSatisfies( tracked, )) => { diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index 77b96bd74bbc1..2132062a7c7d8 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -432,6 +432,17 @@ impl<'db> ConstraintSet<'db> { } } + /// Quantifies over this constraint set so that it only contains constraints that mention the + /// given typevar. All other typevars are quantified away. + pub(crate) fn retain_one( + self, + db: &'db dyn Db, + bound_typevar: BoundTypeVarIdentity<'db>, + ) -> Self { + let node = self.node.retain_one(db, bound_typevar); + Self { node } + } + /// Reduces the set of inferable typevars for this constraint set. You provide an iterator of /// the typevars that were inferable when this constraint set was created, and which should be /// abstracted away. Those typevars will be removed from the constraint set, and the constraint diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index 573756d1b215a..c1811f8b4f594 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -841,6 +841,9 @@ impl<'db> FmtDetailed<'db> for DisplayRepresentation<'db> { KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) => { return f.write_str("bound method `ConstraintSet.implies_subtype_of`"); } + KnownBoundMethodType::ConstraintSetRetainOne(_) => { + return f.write_str("bound method `ConstraintSet.retain_one`"); + } KnownBoundMethodType::ConstraintSetSatisfies(_) => { return f.write_str("bound method `ConstraintSet.satisfies`"); } diff --git a/crates/ty_vendored/ty_extensions/ty_extensions.pyi b/crates/ty_vendored/ty_extensions/ty_extensions.pyi index 347b6b4b3491d..a2dab6fdb3e9c 100644 --- a/crates/ty_vendored/ty_extensions/ty_extensions.pyi +++ b/crates/ty_vendored/ty_extensions/ty_extensions.pyi @@ -67,6 +67,12 @@ class ConstraintSet: .. _subtype: https://typing.python.org/en/latest/spec/concepts.html#subtype-supertype-and-type-equivalence """ + def retain_one(self, typevar: Any) -> Self: + """ + Quantifies over this constraint set so that it only contains constraints + that mention `typevar`. All other typevars are quantified away. + """ + def satisfies(self, other: Self) -> Self: """ Returns whether this constraint set satisfies another — that is, whether From b051943504ff3c2834d3c0b163dce65fc976c7bb Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Thu, 18 Dec 2025 09:00:07 -0500 Subject: [PATCH 02/16] add exists too --- crates/ty_python_semantic/src/types.rs | 37 +++++++++++++++++++ .../ty_python_semantic/src/types/call/bind.rs | 19 ++++++++++ .../src/types/constraints.rs | 12 ++++++ .../ty_python_semantic/src/types/display.rs | 3 ++ .../ty_extensions/ty_extensions.pyi | 8 ++++ 5 files changed, 79 insertions(+) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 6a1245e87ebc2..85c1bb3091308 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -5026,6 +5026,12 @@ impl<'db> Type<'db> { )) .into() } + Type::KnownInstance(KnownInstanceType::ConstraintSet(tracked)) if name == "exists" => { + Place::bound(Type::KnownBoundMethod( + KnownBoundMethodType::ConstraintSetExists(tracked), + )) + .into() + } Type::KnownInstance(KnownInstanceType::ConstraintSet(tracked)) if name == "implies_subtype_of" => { @@ -8043,6 +8049,7 @@ impl<'db> Type<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) @@ -8263,6 +8270,7 @@ impl<'db> Type<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) @@ -12684,6 +12692,7 @@ pub enum KnownBoundMethodType<'db> { ConstraintSetRange, ConstraintSetAlways, ConstraintSetNever, + ConstraintSetExists(TrackedConstraintSet<'db>), ConstraintSetImpliesSubtypeOf(TrackedConstraintSet<'db>), ConstraintSetRetainOne(TrackedConstraintSet<'db>), ConstraintSetSatisfies(TrackedConstraintSet<'db>), @@ -12717,6 +12726,7 @@ pub(super) fn walk_method_wrapper_type<'db, V: visitor::TypeVisitor<'db> + ?Size KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) @@ -12785,6 +12795,10 @@ impl<'db> KnownBoundMethodType<'db> { KnownBoundMethodType::ConstraintSetNever, KnownBoundMethodType::ConstraintSetNever, ) + | ( + KnownBoundMethodType::ConstraintSetExists(_), + KnownBoundMethodType::ConstraintSetExists(_), + ) | ( KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_), KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_), @@ -12815,6 +12829,7 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) @@ -12828,6 +12843,7 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) @@ -12882,6 +12898,10 @@ impl<'db> KnownBoundMethodType<'db> { ) => ConstraintSet::from(true), ( + KnownBoundMethodType::ConstraintSetExists(left_constraints), + KnownBoundMethodType::ConstraintSetExists(right_constraints), + ) + | ( KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(left_constraints), KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(right_constraints), ) @@ -12914,6 +12934,7 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) @@ -12927,6 +12948,7 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) @@ -12954,6 +12976,7 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) @@ -12993,6 +13016,7 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) @@ -13012,6 +13036,7 @@ impl<'db> KnownBoundMethodType<'db> { KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever + | KnownBoundMethodType::ConstraintSetExists(_) | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) | KnownBoundMethodType::ConstraintSetRetainOne(_) | KnownBoundMethodType::ConstraintSetSatisfies(_) @@ -13159,6 +13184,18 @@ impl<'db> KnownBoundMethodType<'db> { ))) } + KnownBoundMethodType::ConstraintSetExists(_) => { + Either::Right(std::iter::once(Signature::new( + Parameters::new( + db, + [Parameter::variadic(Name::new_static("typevars")) + .type_form() + .with_annotated_type(Type::any())], + ), + Some(KnownClass::ConstraintSet.to_instance(db)), + ))) + } + KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) => { Either::Right(std::iter::once(Signature::new( Parameters::new( diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index c05a5db9d39bd..c17a2c0c2c760 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -1255,6 +1255,25 @@ impl<'db> Bindings<'db> { )); } + Type::KnownBoundMethod(KnownBoundMethodType::ConstraintSetExists(tracked)) => { + let typevars: Option> = overload + .arguments_for_parameter(argument_types, 0) + .map(|(_, ty)| { + ty.as_typevar() + .map(|bound_typevar| bound_typevar.identity(db)) + }) + .collect(); + let Some(typevars) = typevars else { + continue; + }; + + let result = tracked.constraints(db).exists(db, typevars); + let tracked = TrackedConstraintSet::new(db, result); + overload.set_return_type(Type::KnownInstance( + KnownInstanceType::ConstraintSet(tracked), + )); + } + Type::KnownBoundMethod( KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(tracked), ) => { diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index 2132062a7c7d8..58058db1bd15f 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -432,6 +432,18 @@ impl<'db> ConstraintSet<'db> { } } + /// Returns a new constraint set that is the _existential abstraction_ of `self` for a set of + /// typevars. The result will return true whenever `self` returns true for _any_ assignment of + /// those typevars. The result will not contain any constraints that mention those typevars. + pub(crate) fn exists( + self, + db: &'db dyn Db, + bound_typevars: impl IntoIterator>, + ) -> Self { + let node = self.node.exists(db, bound_typevars); + Self { node } + } + /// Quantifies over this constraint set so that it only contains constraints that mention the /// given typevar. All other typevars are quantified away. pub(crate) fn retain_one( diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index c1811f8b4f594..e8e9d4662520c 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -838,6 +838,9 @@ impl<'db> FmtDetailed<'db> for DisplayRepresentation<'db> { KnownBoundMethodType::ConstraintSetNever => { return f.write_str("bound method `ConstraintSet.never`"); } + KnownBoundMethodType::ConstraintSetExists(_) => { + return f.write_str("bound method `ConstraintSet.exists`"); + } KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) => { return f.write_str("bound method `ConstraintSet.implies_subtype_of`"); } diff --git a/crates/ty_vendored/ty_extensions/ty_extensions.pyi b/crates/ty_vendored/ty_extensions/ty_extensions.pyi index a2dab6fdb3e9c..1ab24ca55592e 100644 --- a/crates/ty_vendored/ty_extensions/ty_extensions.pyi +++ b/crates/ty_vendored/ty_extensions/ty_extensions.pyi @@ -59,6 +59,14 @@ class ConstraintSet: def never() -> Self: """Returns a constraint set that is never satisfied""" + def exists(self, *typevars: Any) -> Self: + """ + Returns a new constraint set that is the _existential abstraction_ of + `self` for a set of typevars. The result will return true whenever + `self` returns true for _any_ assignment of those typevars. The result + will not contain any constraints that mention those typevars. + """ + def implies_subtype_of(self, ty: Any, of: Any) -> Self: """ Returns a constraint set that is satisfied when `ty` is a `subtype`_ of From f22890bd88412e9d8b852849cb5a9eb8a2ae90fd Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Wed, 17 Dec 2025 17:06:20 -0500 Subject: [PATCH 03/16] add quantification test cases --- .../mdtest/generics/specialize_constrained.md | 45 ++- .../mdtest/type_properties/quantification.md | 279 ++++++++++++++++++ 2 files changed, 314 insertions(+), 10 deletions(-) create mode 100644 crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md diff --git a/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md b/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md index 32956cdfa8f6a..4f054dd1ac4ec 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md @@ -416,19 +416,44 @@ def mutually_bound[T: Base, U](): ## Nested typevars -A typevar's constraint can _mention_ another typevar without _constraining_ it. In this example, `U` -must be specialized to `list[T]`, but it cannot affect what `T` is specialized to. +The specialization of one typevar can affect the specialization of another, even if it is not a +"top-level" type in the bounds. (That is, if it appears as inside the specialization of a generic +class.) ```py from typing import Never from ty_extensions import ConstraintSet, generic_context -def mentions[T, U](): - # (T@mentions ≤ int) ∧ (U@mentions = list[T@mentions]) - constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(list[T], U, list[T]) - # TODO: revealed: ty_extensions.Specialization[T@mentions = int, U@mentions = list[int]] - # revealed: ty_extensions.Specialization[T@mentions = int, U@mentions = Unknown] - reveal_type(generic_context(mentions).specialize_constrained(constraints)) +class Covariant[T]: + def get(self) -> T: + raise NotImplementedError + +class Contravariant[T]: + def receive(self, input: T): ... + +class Invariant[T]: + mutable_attribute: T + +def mentions_covariant[T, U](): + # (T@mentions_covariant ≤ int) ∧ (U@mentions_covariant ≤ Covariant[T@mentions_covariant]) + constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(Never, U, Covariant[T]) + # TODO: revealed: ty_extensions.Specialization[T@mentions_covariant = int, U@mentions_covariant = Covariant[int]] + # revealed: ty_extensions.Specialization[T@mentions_covariant = int, U@mentions_covariant = Unknown] + reveal_type(generic_context(mentions_covariant).specialize_constrained(constraints)) + +def mentions_contravariant[T, U](): + # (T@mentions_contravariant ≤ int) ∧ (Contravariant[T@mentions_contravariant] ≤ U@mentions_contravariant) + constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(Contravariant[T], U, object) + # TODO: revealed: ty_extensions.Specialization[T@mentions_contravariant = int, U@mentions_contravariant = Contravariant[int]] + # revealed: ty_extensions.Specialization[T@mentions_contravariant = int, U@mentions_contravariant = Unknown] + reveal_type(generic_context(mentions_contravariant).specialize_constrained(constraints)) + +def mentions_invariant[T, U](): + # (T@mentions_invariant ≤ int) ∧ (U@mentions_invariant = Invariant[T@mentions_invariant]) + constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(Invariant[T], U, Invariant[T]) + # TODO: revealed: ty_extensions.Specialization[T@mentions_invariant = int, U@mentions_invariant = Invariant[int]] + # revealed: ty_extensions.Specialization[T@mentions_invariant = int, U@mentions_invariant = Unknown] + reveal_type(generic_context(mentions_invariant).specialize_constrained(constraints)) ``` If the constraint set contains mutually recursive bounds, specialization inference will not @@ -437,8 +462,8 @@ this case. ```py def divergent[T, U](): - # (T@divergent = list[U@divergent]) ∧ (U@divergent = list[T@divergent])) - constraints = ConstraintSet.range(list[U], T, list[U]) & ConstraintSet.range(list[T], U, list[T]) + # (T@divergent = Invariant[U@divergent]) ∧ (U@divergent = Invariant[T@divergent])) + constraints = ConstraintSet.range(Invariant[U], T, Invariant[U]) & ConstraintSet.range(Invariant[T], U, Invariant[T]) # revealed: None reveal_type(generic_context(divergent).specialize_constrained(constraints)) ``` diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md b/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md new file mode 100644 index 0000000000000..a09c656351446 --- /dev/null +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md @@ -0,0 +1,279 @@ +# Constraint set quantification + +```toml +[environment] +python-version = "3.12" +``` + +We can _existentially quantify_ a constraint set over a type variable. The result is a copy of the +constraint set that only mentions the requested typevar. All constraints mentioning any other +typevars are removed. Importantly, they are removed "safely", with their constraints propagated +through to the remaining constraints as needed. + +## Keeping a single typevar + +If a constraint set only mentions a single typevar, and we keep that typevar when quantifying, the +result is unchanged. + +```py +from ty_extensions import ConstraintSet, static_assert + +class Base: ... +class Sub(Base): ... + +def keep_single[T](): + constraints = ConstraintSet.always() + quantified = ConstraintSet.always() + static_assert(constraints.retain_one(T) == quantified) + + constraints = ConstraintSet.never() + quantified = ConstraintSet.never() + static_assert(constraints.retain_one(T) == quantified) + + constraints = ConstraintSet.range(Sub, T, Base) + quantified = ConstraintSet.range(Sub, T, Base) + static_assert(constraints.retain_one(T) == quantified) +``` + +## Removing a single typevar + +If a constraint set only mentions a single typevar, and we remove that typevar when quantifying, +the result is usually "always". The only exception is if the original constraint set has no +solution. In that case, the result is also unsatisfiable. + +```py +from ty_extensions import ConstraintSet, static_assert + +class Base: ... +class Sub(Base): ... + +def remove_single[T](): + constraints = ConstraintSet.always() + quantified = ConstraintSet.always() + static_assert(constraints.exists(T) == quantified) + + constraints = ConstraintSet.never() + quantified = ConstraintSet.never() + static_assert(constraints.exists(T) == quantified) + + constraints = ConstraintSet.range(Sub, T, Base) + quantified = ConstraintSet.always() + static_assert(constraints.exists(T) == quantified) +``` + +This also holds when the constraint set contains multiple typevars. In the cases below, we are +keeping `U`, and the constraints on `T` do not ever affect what `U` can specialize to — `U` can +specialize to anything (unless the original constraint set is unsatisfiable). + +```py +from ty_extensions import ConstraintSet, static_assert + +class Base: ... +class Sub(Base): ... + +def remove_other[T, U](): + constraints = ConstraintSet.always() + quantified = ConstraintSet.always() + static_assert(constraints.retain_one(U) == quantified) + + constraints = ConstraintSet.never() + quantified = ConstraintSet.never() + static_assert(constraints.retain_one(U) == quantified) + + constraints = ConstraintSet.range(Sub, T, Base) + quantified = ConstraintSet.always() + static_assert(constraints.retain_one(U) == quantified) +``` + +## Transitivity + +When a constraint set mentions two typevars, and compares them directly, then we can use +transitivity to propagate the other constraints when quantifying. + +```py +from typing import Never +from ty_extensions import ConstraintSet, static_assert + +class Super: ... +class Base(Super): ... +class Sub(Base): ... + +def transitivity[T, U](): + # (Base ≤ T) ∧ (T ≤ U) → (Base ≤ U) + constraints = ConstraintSet.range(Base, T, object) & ConstraintSet.range(T, U, object) + quantified = ConstraintSet.range(Base, U, object) + static_assert(constraints.exists(T) == quantified) + + # (Base ≤ T ≤ Super) ∧ (T ≤ U) → (Base ≤ U) + constraints = ConstraintSet.range(Base, T, Super) & ConstraintSet.range(T, U, object) + quantified = ConstraintSet.range(Base, U, object) + static_assert(constraints.exists(T) == quantified) + + # (T ≤ Base) ∧ (U ≤ T) → (U ≤ Base) + constraints = ConstraintSet.range(Never, T, Base) & ConstraintSet.range(Never, U, T) + quantified = ConstraintSet.range(Never, U, Base) + static_assert(constraints.exists(T) == quantified) + + # (Sub ≤ T ≤ Base) ∧ (U ≤ T) → (U ≤ Base) + constraints = ConstraintSet.range(Sub, T, Base) & ConstraintSet.range(Never, U, T) + quantified = ConstraintSet.range(Never, U, Base) + static_assert(constraints.exists(T) == quantified) +``` + +## Covariant transitivity + +The same applies when one of the typevars is used covariantly in a bound of the other typevar. + +```py +from typing import Never +from ty_extensions import ConstraintSet, static_assert + +class Super: ... +class Base(Super): ... +class Sub(Base): ... + +class Covariant[T]: + def get(self) -> T: + raise NotImplementedError + +def covariant_transitivity[T, U](): + # (Base ≤ T) ∧ (Covariant[T] ≤ U) → (Covariant[Base] ≤ U) + constraints = ConstraintSet.range(Base, T, object) & ConstraintSet.range(Covariant[T], U, object) + quantified = ConstraintSet.range(Covariant[Base], U, object) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (Base ≤ T ≤ Super) ∧ (Covariant[T] ≤ U) → (Covariant[Base] ≤ U) + constraints = ConstraintSet.range(Base, T, Super) & ConstraintSet.range(Covariant[T], U, object) + quantified = ConstraintSet.range(Covariant[Base], U, object) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T ≤ Base) ∧ (U ≤ Covariant[T]) → (U ≤ Covariant[Base]) + constraints = ConstraintSet.range(Never, T, Base) & ConstraintSet.range(Never, U, Covariant[T]) + quantified = ConstraintSet.range(Never, U, Covariant[Base]) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (Sub ≤ T ≤ Base) ∧ (U ≤ Covariant[T]) → (U ≤ Covariant[Base]) + constraints = ConstraintSet.range(Sub, T, Base) & ConstraintSet.range(Never, U, Covariant[T]) + quantified = ConstraintSet.range(Never, U, Covariant[Base]) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) +``` + +## Contravariant transitivity + +Similar rules apply, but in reverse, when one of the typevars is used contravariantly in a bound of +the other typevar. + +```py +from typing import Never +from ty_extensions import ConstraintSet, static_assert + +class Super: ... +class Base(Super): ... +class Sub(Base): ... + +class Contravariant[T]: + def receive(self, input: T): ... + +def contravariant_transitivity[T, U](): + # (Base ≤ T) ∧ (U ≤ Contravariant[T]) → (U ≤ Contravariant[Base]) + constraints = ConstraintSet.range(Base, T, object) & ConstraintSet.range(Never, U, Contravariant[T]) + quantified = ConstraintSet.range(Never, U, Contravariant[Base]) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (Base ≤ T ≤ Super) ∧ (U ≤ Contravariant[T]) → (U ≤ Contravariant[Base]) + constraints = ConstraintSet.range(Base, T, Super) & ConstraintSet.range(Never, U, Contravariant[T]) + quantified = ConstraintSet.range(Never, U, Contravariant[Base]) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T ≤ Base) ∧ (Contravariant[T] ≤ U) → (Contravariant[Base] ≤ U) + constraints = ConstraintSet.range(Never, T, Base) & ConstraintSet.range(Contravariant[T], U, object) + quantified = ConstraintSet.range(Contravariant[Base], U, object) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (Sub ≤ T ≤ Base) ∧ (Contravariant[T] ≤ U) → (Contravariant[Base] ≤ U) + constraints = ConstraintSet.range(Sub, T, Base) & ConstraintSet.range(Contravariant[T], U, object) + quantified = ConstraintSet.range(Contravariant[Base], U, object) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) +``` + +## Invariant transitivity involving equality constraints + +Invariant uses of a typevar are more subtle. The simplest case is when there is an _equality_ +constraint on the invariant typevar. In that case, we know precisely which specialization is +required. + +```py +from typing import Never +from ty_extensions import ConstraintSet, static_assert + +class Base: ... + +class Invariant[T]: + mutable_attribute: T + +def invariant_equality_transitivity[T, U](): + # (T = Base) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[Base]) + constraints = ConstraintSet.range(Base, T, Base) & ConstraintSet.range(Never, U, Invariant[T]) + quantified = ConstraintSet.range(Never, U, Invariant[Base]) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T = Base) ∧ (Invariant[T] ≤ U) → (Invariant[Base] ≤ U) + constraints = ConstraintSet.range(Base, T, Base) & ConstraintSet.range(Invariant[T], U, object) + quantified = ConstraintSet.range(Invariant[Base], U, object) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) +``` + +## Invariant transitivity involving range constraints + +When there is a _range_ constraint on the invariant typevar, we still have to retain information +about which range of types the quantified-away typevar can specialize to, since this affects which +types the remaining typevar can specialize to, and invariant typevars are not monotonic like +covariant and contravariant typevars. + +```py +from typing import Never +from ty_extensions import ConstraintSet, static_assert + +class Base: ... +class Sub(Base): ... + +class Invariant[T]: + mutable_attribute: T + +def invariant_range_transitivity[T, U](): + # (Sub ≤ T ≤ Base) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[Exists[Sub, Base]]) + constraints = ConstraintSet.range(Sub, T, Base) & ConstraintSet.range(Never, U, Invariant[T]) + # TODO: The existential that we need doesn't exist yet. + quantified = ConstraintSet.never() + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (Sub ≤ T ≤ Base) ∧ (Invariant[T] ≤ U) → (Invariant[Exists[Sub, Base]] ≤ U) + constraints = ConstraintSet.range(Sub, T, Base) & ConstraintSet.range(Invariant[T], U, object) + # TODO: The existential that we need doesn't exist yet. + quantified = ConstraintSet.never() + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) +``` From 7d38dc8685277461aa59e3c5cd2fe3126811f3cd Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Wed, 17 Dec 2025 17:59:40 -0500 Subject: [PATCH 04/16] restructure a bit --- .../src/types/constraints.rs | 61 ++++++++++++------- 1 file changed, 39 insertions(+), 22 deletions(-) diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index 58058db1bd15f..88889601f4bd3 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -3050,29 +3050,41 @@ impl<'db> SequentMap<'db> { left_constraint: ConstrainedTypeVar<'db>, right_constraint: ConstrainedTypeVar<'db>, ) { - // We've structured our constraints so that a typevar's upper/lower bound can only - // be another typevar if the bound is "later" in our arbitrary ordering. That means - // we only have to check this pair of constraints in one direction — though we do - // have to figure out which of the two typevars is constrained, and which one is - // the upper/lower bound. + // Add sequents when the two typevars are mutually constrained directly — that is, when the + // lower/upper bound _is_ the typevar, not _contains_ the typevar. We've structured our + // constraints so that a typevar's upper/lower bound can only be another typevar if the + // bound is "later" in our arbitrary ordering. That means we only have to check this pair + // of constraints in one direction — though we do have to figure out which of the two + // typevars is constrained, and which one is the upper/lower bound. let left_typevar = left_constraint.typevar(db); let right_typevar = right_constraint.typevar(db); - let (bound_typevar, bound_constraint, constrained_typevar, constrained_constraint) = - if left_typevar.can_be_bound_for(db, right_typevar) { - ( - left_typevar, - left_constraint, - right_typevar, - right_constraint, - ) - } else { - ( - right_typevar, - right_constraint, - left_typevar, - left_constraint, - ) - }; + if left_typevar.can_be_bound_for(db, right_typevar) { + self.add_direct_mutual_sequents_for_different_typevars( + db, + left_constraint, + left_typevar, + right_constraint, + right_typevar, + ); + } else { + self.add_direct_mutual_sequents_for_different_typevars( + db, + right_constraint, + right_typevar, + left_constraint, + left_typevar, + ); + } + } + + fn add_direct_mutual_sequents_for_different_typevars( + &mut self, + db: &'db dyn Db, + bound_constraint: ConstrainedTypeVar<'db>, + bound_typevar: BoundTypeVarInstance<'db>, + constrained_constraint: ConstrainedTypeVar<'db>, + constrained_typevar: BoundTypeVarInstance<'db>, + ) { // We then look for cases where the "constrained" typevar's upper and/or lower bound // matches the "bound" typevar. If so, we're going to add an implication sequent that @@ -3127,7 +3139,12 @@ impl<'db> SequentMap<'db> { let post_constraint = ConstrainedTypeVar::new(db, constrained_typevar, new_lower, new_upper); - self.add_pair_implication(db, left_constraint, right_constraint, post_constraint); + self.add_pair_implication( + db, + bound_constraint, + constrained_constraint, + post_constraint, + ); self.enqueue_constraint(post_constraint); } From 566bfc5ecbe37ecaaa11f18e1f129dc57d9df6bf Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Wed, 17 Dec 2025 19:23:56 -0500 Subject: [PATCH 05/16] partial spec is an enum --- .../ty_python_semantic/src/types/generics.rs | 39 ++++++++++++------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index ca6a700bd240f..baa991e333e99 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -568,7 +568,7 @@ impl<'db> GenericContext<'db> { loop { let mut any_changed = false; for i in 0..len { - let partial = PartialSpecialization { + let partial = PartialSpecialization::FromGenericContext { generic_context: self, types: &types, // Don't recursively substitute type[i] in itself. Ideally, we could instead @@ -646,7 +646,7 @@ impl<'db> GenericContext<'db> { // Typevars are only allowed to refer to _earlier_ typevars in their defaults. (This is // statically enforced for PEP-695 contexts, and is explicitly called out as a // requirement for legacy contexts.) - let partial = PartialSpecialization { + let partial = PartialSpecialization::FromGenericContext { generic_context: self, types: &expanded[0..idx], skip: None, @@ -1452,12 +1452,14 @@ impl<'db> Specialization<'db> { /// You will usually use [`Specialization`] instead of this type. This type is used when we need to /// substitute types for type variables before we have fully constructed a [`Specialization`]. #[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize)] -pub struct PartialSpecialization<'a, 'db> { - generic_context: GenericContext<'db>, - types: &'a [Type<'db>], - /// An optional typevar to _not_ substitute when applying the specialization. We use this to - /// avoid recursively substituting a type inside of itself. - skip: Option, +pub enum PartialSpecialization<'a, 'db> { + FromGenericContext { + generic_context: GenericContext<'db>, + types: &'a [Type<'db>], + /// An optional typevar to _not_ substitute when applying the specialization. We use this to + /// avoid recursively substituting a type inside of itself. + skip: Option, + }, } impl<'db> PartialSpecialization<'_, 'db> { @@ -1468,14 +1470,21 @@ impl<'db> PartialSpecialization<'_, 'db> { db: &'db dyn Db, bound_typevar: BoundTypeVarInstance<'db>, ) -> Option> { - let index = self - .generic_context - .variables_inner(db) - .get_index_of(&bound_typevar.identity(db))?; - if self.skip.is_some_and(|skip| skip == index) { - return Some(Type::Never); + match self { + PartialSpecialization::FromGenericContext { + generic_context, + types, + skip, + } => { + let index = generic_context + .variables_inner(db) + .get_index_of(&bound_typevar.identity(db))?; + if skip.is_some_and(|skip| skip == index) { + return Some(Type::Never); + } + types.get(index).copied() + } } - self.types.get(index).copied() } } From 61139355788ee5a988a475df89ee166971b902fd Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Wed, 17 Dec 2025 19:25:18 -0500 Subject: [PATCH 06/16] partial specialize single typevar --- crates/ty_python_semantic/src/types/generics.rs | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index baa991e333e99..ad7529eb2244d 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -1460,6 +1460,10 @@ pub enum PartialSpecialization<'a, 'db> { /// avoid recursively substituting a type inside of itself. skip: Option, }, + Single { + bound_typevar: BoundTypeVarInstance<'db>, + ty: Type<'db>, + }, } impl<'db> PartialSpecialization<'_, 'db> { @@ -1468,7 +1472,7 @@ impl<'db> PartialSpecialization<'_, 'db> { pub(crate) fn get( &self, db: &'db dyn Db, - bound_typevar: BoundTypeVarInstance<'db>, + needle_bound_typevar: BoundTypeVarInstance<'db>, ) -> Option> { match self { PartialSpecialization::FromGenericContext { @@ -1478,12 +1482,19 @@ impl<'db> PartialSpecialization<'_, 'db> { } => { let index = generic_context .variables_inner(db) - .get_index_of(&bound_typevar.identity(db))?; + .get_index_of(&needle_bound_typevar.identity(db))?; if skip.is_some_and(|skip| skip == index) { return Some(Type::Never); } types.get(index).copied() } + PartialSpecialization::Single { bound_typevar, ty } => { + if bound_typevar.is_same_typevar_as(db, needle_bound_typevar) { + Some(*ty) + } else { + None + } + } } } } From 026b753735f6550a1a7b566fb59f4f9038b89540 Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Wed, 17 Dec 2025 19:35:17 -0500 Subject: [PATCH 07/16] handle the covariant case --- .../mdtest/generics/specialize_constrained.md | 3 +- .../mdtest/type_properties/quantification.md | 8 -- .../src/types/constraints.rs | 124 +++++++++++++++++- 3 files changed, 121 insertions(+), 14 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md b/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md index 4f054dd1ac4ec..5c37e118bfb8d 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md @@ -437,8 +437,7 @@ class Invariant[T]: def mentions_covariant[T, U](): # (T@mentions_covariant ≤ int) ∧ (U@mentions_covariant ≤ Covariant[T@mentions_covariant]) constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(Never, U, Covariant[T]) - # TODO: revealed: ty_extensions.Specialization[T@mentions_covariant = int, U@mentions_covariant = Covariant[int]] - # revealed: ty_extensions.Specialization[T@mentions_covariant = int, U@mentions_covariant = Unknown] + # revealed: ty_extensions.Specialization[T@mentions_covariant = int, U@mentions_covariant = Covariant[int]] reveal_type(generic_context(mentions_covariant).specialize_constrained(constraints)) def mentions_contravariant[T, U](): diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md b/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md index a09c656351446..74386973f261b 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md @@ -140,29 +140,21 @@ def covariant_transitivity[T, U](): # (Base ≤ T) ∧ (Covariant[T] ≤ U) → (Covariant[Base] ≤ U) constraints = ConstraintSet.range(Base, T, object) & ConstraintSet.range(Covariant[T], U, object) quantified = ConstraintSet.range(Covariant[Base], U, object) - # TODO: no error - # error: [static-assert-error] static_assert(constraints.exists(T) == quantified) # (Base ≤ T ≤ Super) ∧ (Covariant[T] ≤ U) → (Covariant[Base] ≤ U) constraints = ConstraintSet.range(Base, T, Super) & ConstraintSet.range(Covariant[T], U, object) quantified = ConstraintSet.range(Covariant[Base], U, object) - # TODO: no error - # error: [static-assert-error] static_assert(constraints.exists(T) == quantified) # (T ≤ Base) ∧ (U ≤ Covariant[T]) → (U ≤ Covariant[Base]) constraints = ConstraintSet.range(Never, T, Base) & ConstraintSet.range(Never, U, Covariant[T]) quantified = ConstraintSet.range(Never, U, Covariant[Base]) - # TODO: no error - # error: [static-assert-error] static_assert(constraints.exists(T) == quantified) # (Sub ≤ T ≤ Base) ∧ (U ≤ Covariant[T]) → (U ≤ Covariant[Base]) constraints = ConstraintSet.range(Sub, T, Base) & ConstraintSet.range(Never, U, Covariant[T]) quantified = ConstraintSet.range(Never, U, Covariant[Base]) - # TODO: no error - # error: [static-assert-error] static_assert(constraints.exists(T) == quantified) ``` diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index 88889601f4bd3..dc78203e0319b 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -75,13 +75,16 @@ use itertools::Itertools; use rustc_hash::{FxHashMap, FxHashSet}; use salsa::plumbing::AsId; -use crate::types::generics::{GenericContext, InferableTypeVars, Specialization}; +use crate::types::generics::{ + GenericContext, InferableTypeVars, PartialSpecialization, Specialization, +}; +use crate::types::variance::VarianceInferable; use crate::types::visitor::{ TypeCollector, TypeVisitor, any_over_type, walk_type_with_recursion_guard, }; use crate::types::{ - BoundTypeVarIdentity, BoundTypeVarInstance, IntersectionType, Type, TypeVarBoundOrConstraints, - UnionType, walk_bound_type_var_type, + BoundTypeVarIdentity, BoundTypeVarInstance, IntersectionType, Type, TypeContext, TypeMapping, + TypeVarBoundOrConstraints, TypeVarVariance, UnionType, walk_bound_type_var_type, }; use crate::{Db, FxOrderMap}; @@ -3075,6 +3078,24 @@ impl<'db> SequentMap<'db> { left_typevar, ); } + + // Add sequents when one of the typevars is mentioned deeply inside the bounds of the + // other. Because the "bound" typevar appears deeply inside of the "constrained" typevar, + // our constraint ordering doesn't apply, and we have to try each direction. + self.add_deep_mutual_sequents( + db, + left_constraint, + left_typevar, + right_constraint, + right_typevar, + ); + self.add_deep_mutual_sequents( + db, + right_constraint, + right_typevar, + left_constraint, + left_typevar, + ); } fn add_direct_mutual_sequents_for_different_typevars( @@ -3085,7 +3106,6 @@ impl<'db> SequentMap<'db> { constrained_constraint: ConstrainedTypeVar<'db>, constrained_typevar: BoundTypeVarInstance<'db>, ) { - // We then look for cases where the "constrained" typevar's upper and/or lower bound // matches the "bound" typevar. If so, we're going to add an implication sequent that // replaces the upper/lower bound that matched with the bound constraint's corresponding @@ -3148,6 +3168,102 @@ impl<'db> SequentMap<'db> { self.enqueue_constraint(post_constraint); } + fn add_deep_mutual_sequents( + &mut self, + db: &'db dyn Db, + bound_constraint: ConstrainedTypeVar<'db>, + bound_typevar: BoundTypeVarInstance<'db>, + constrained_constraint: ConstrainedTypeVar<'db>, + constrained_typevar: BoundTypeVarInstance<'db>, + ) { + let bound_lower = bound_constraint.lower(db); + let bound_upper = bound_constraint.upper(db); + let constrained_lower = constrained_constraint.lower(db); + let constrained_upper = constrained_constraint.upper(db); + + let post_constraint = match constrained_lower.variance_of(db, bound_typevar) { + // B does not appear in CU, or if it does, it appears bivariantly. The constraints of B + // do not affect the valid specializations of C. + TypeVarVariance::Bivariant => None, + + // (Covariant[B] ≤ C ≤ CU) ∧ (BL ≤ B ≤ BU) → (Covariant[BL] ≤ C ≤ CU) + TypeVarVariance::Covariant => { + if !bound_lower.is_never() && !bound_lower.is_object() { + let partial = PartialSpecialization::Single { + bound_typevar, + ty: bound_lower, + }; + let new_lower = constrained_lower.apply_type_mapping( + db, + &TypeMapping::PartialSpecialization(partial), + TypeContext::default(), + ); + Some(ConstrainedTypeVar::new( + db, + constrained_typevar, + new_lower, + constrained_upper, + )) + } else { + None + } + } + + // TODO + TypeVarVariance::Contravariant | TypeVarVariance::Invariant => None, + }; + if let Some(post_constraint) = post_constraint { + self.add_pair_implication( + db, + bound_constraint, + constrained_constraint, + post_constraint, + ); + self.enqueue_constraint(post_constraint); + } + + let post_constraint = match constrained_upper.variance_of(db, bound_typevar) { + // B does not appear in CU, or if it does, it appears bivariantly. The constraints of B + // do not affect the valid specializations of C. + TypeVarVariance::Bivariant => None, + + // (CL ≤ C ≤ Covariant[B]) ∧ (BL ≤ B ≤ BU) → (CL ≤ C ≤ Covariant[BU]) + TypeVarVariance::Covariant => { + if !bound_upper.is_never() && !bound_upper.is_object() { + let partial = PartialSpecialization::Single { + bound_typevar, + ty: bound_upper, + }; + let new_upper = constrained_upper.apply_type_mapping( + db, + &TypeMapping::PartialSpecialization(partial), + TypeContext::default(), + ); + Some(ConstrainedTypeVar::new( + db, + constrained_typevar, + constrained_lower, + new_upper, + )) + } else { + None + } + } + + // TODO + TypeVarVariance::Contravariant | TypeVarVariance::Invariant => None, + }; + if let Some(post_constraint) = post_constraint { + self.add_pair_implication( + db, + bound_constraint, + constrained_constraint, + post_constraint, + ); + self.enqueue_constraint(post_constraint); + } + } + fn add_mutual_sequents_for_same_typevars( &mut self, db: &'db dyn Db, From 804842ac80f9f808fa2fa9178dd818099305a10d Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Wed, 17 Dec 2025 21:07:16 -0500 Subject: [PATCH 08/16] failing tests for propagating typevars --- .../mdtest/type_properties/quantification.md | 140 ++++++++++++++++++ 1 file changed, 140 insertions(+) diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md b/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md index 74386973f261b..44430fcb05cf9 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md @@ -158,6 +158,37 @@ def covariant_transitivity[T, U](): static_assert(constraints.exists(T) == quantified) ``` +Same as above, but when propagating a third typevar instead of a concrete type. We make sure to test +with both variable orderings for the constraint that involves two typevars. + +```py +def covariant_typevar_transitivity[B, T, U](): + # (B ≤ T) ∧ (Covariant[T] ≤ U) → (Covariant[B] ≤ U) + constraints = ConstraintSet.range(B, T, object) & ConstraintSet.range(Covariant[T], U, object) + quantified = ConstraintSet.range(Covariant[B], U, object) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T ≤ B) ∧ (U ≤ Covariant[T]) → (U ≤ Covariant[B]) + constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Never, U, Covariant[T]) + quantified = ConstraintSet.range(Never, U, Covariant[B]) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + +def covariant_typevar_transitivity_reversed[T, B, U](): + # (B ≤ T) ∧ (Covariant[T] ≤ U) → (Covariant[B] ≤ U) + constraints = ConstraintSet.range(B, T, object) & ConstraintSet.range(Covariant[T], U, object) + quantified = ConstraintSet.range(Covariant[B], U, object) + static_assert(constraints.exists(T) == quantified) + + # (T ≤ B) ∧ (U ≤ Covariant[T]) → (U ≤ Covariant[B]) + constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Never, U, Covariant[T]) + quantified = ConstraintSet.range(Never, U, Covariant[B]) + static_assert(constraints.exists(T) == quantified) +``` + ## Contravariant transitivity Similar rules apply, but in reverse, when one of the typevars is used contravariantly in a bound of @@ -204,6 +235,41 @@ def contravariant_transitivity[T, U](): static_assert(constraints.exists(T) == quantified) ``` +Same as above, but when propagating a third typevar instead of a concrete type. We make sure to test +with both variable orderings for the constraint that involves two typevars. + +```py +def contravariant_typevar_transitivity[B, T, U](): + # (B ≤ T) ∧ (U ≤ Contravariant[T]) → (U ≤ Contravariant[B]) + constraints = ConstraintSet.range(B, T, object) & ConstraintSet.range(Never, U, Contravariant[T]) + quantified = ConstraintSet.range(Never, U, Contravariant[B]) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T ≤ B) ∧ (Contravariant[T] ≤ U) → (Contravariant[B] ≤ U) + constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Contravariant[T], U, object) + quantified = ConstraintSet.range(Contravariant[B], U, object) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + +def contravariant_typevar_transitivity_reversed[T, B, U](): + # (B ≤ T) ∧ (U ≤ Contravariant[T]) → (U ≤ Contravariant[B]) + constraints = ConstraintSet.range(B, T, object) & ConstraintSet.range(Never, U, Contravariant[T]) + quantified = ConstraintSet.range(Never, U, Contravariant[B]) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T ≤ B) ∧ (Contravariant[T] ≤ U) → (Contravariant[B] ≤ U) + constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Contravariant[T], U, object) + quantified = ConstraintSet.range(Contravariant[B], U, object) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) +``` + ## Invariant transitivity involving equality constraints Invariant uses of a typevar are more subtle. The simplest case is when there is an _equality_ @@ -235,6 +301,41 @@ def invariant_equality_transitivity[T, U](): static_assert(constraints.exists(T) == quantified) ``` +Same as above, but when propagating a third typevar instead of a concrete type. We make sure to test +with both variable orderings for the constraint that involves two typevars. + +```py +def invariant_equality_typevar_transitivity[B, T, U](): + # (T = B) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[B]) + constraints = ConstraintSet.range(B, T, B) & ConstraintSet.range(Never, U, Invariant[T]) + quantified = ConstraintSet.range(Never, U, Invariant[B]) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T = B) ∧ (Invariant[T] ≤ U) → (Invariant[B] ≤ U) + constraints = ConstraintSet.range(B, T, B) & ConstraintSet.range(Invariant[T], U, object) + quantified = ConstraintSet.range(Invariant[B], U, object) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + +def invariant_equality_typevar_transitivity_reverse[T, B, U](): + # (T = B) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[B]) + constraints = ConstraintSet.range(B, T, B) & ConstraintSet.range(Never, U, Invariant[T]) + quantified = ConstraintSet.range(Never, U, Invariant[B]) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T = B) ∧ (Invariant[T] ≤ U) → (Invariant[B] ≤ U) + constraints = ConstraintSet.range(B, T, B) & ConstraintSet.range(Invariant[T], U, object) + quantified = ConstraintSet.range(Invariant[B], U, object) + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) +``` + ## Invariant transitivity involving range constraints When there is a _range_ constraint on the invariant typevar, we still have to retain information @@ -269,3 +370,42 @@ def invariant_range_transitivity[T, U](): # error: [static-assert-error] static_assert(constraints.exists(T) == quantified) ``` + +Same as above, but when propagating a third typevar instead of a concrete type. We make sure to test +with both variable orderings for the constraint that involves two typevars. + +```py +def invariant_range_typevar_transitivity[B, T, U](): + # (T ≤ B) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[Exists[Never, B]]) + constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Never, U, Invariant[T]) + # TODO: The existential that we need doesn't exist yet. + quantified = ConstraintSet.never() + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T ≤ B) ∧ (Invariant[T] ≤ U) → (Invariant[Exists[Never, B]] ≤ U) + constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Invariant[T], U, object) + # TODO: The existential that we need doesn't exist yet. + quantified = ConstraintSet.never() + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + +def invariant_range_typevar_transitivity_reverse[T, B, U](): + # (T ≤ B) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[Exists[Never, B]]) + constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Never, U, Invariant[T]) + # TODO: The existential that we need doesn't exist yet. + quantified = ConstraintSet.never() + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) + + # (T ≤ B) ∧ (Invariant[T] ≤ U) → (Invariant[Exists[Never, B]] ≤ U) + constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Invariant[T], U, object) + # TODO: The existential that we need doesn't exist yet. + quantified = ConstraintSet.never() + # TODO: no error + # error: [static-assert-error] + static_assert(constraints.exists(T) == quantified) +``` From 056414bfacb85e79a40850f821af1b5321cd6b17 Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Thu, 18 Dec 2025 09:31:52 -0500 Subject: [PATCH 09/16] typevar ordering strikes again --- .../mdtest/type_properties/quantification.md | 4 - .../src/types/constraints.rs | 94 +++++++++++++------ 2 files changed, 63 insertions(+), 35 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md b/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md index 44430fcb05cf9..70b6b6f3d3a8c 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md @@ -166,15 +166,11 @@ def covariant_typevar_transitivity[B, T, U](): # (B ≤ T) ∧ (Covariant[T] ≤ U) → (Covariant[B] ≤ U) constraints = ConstraintSet.range(B, T, object) & ConstraintSet.range(Covariant[T], U, object) quantified = ConstraintSet.range(Covariant[B], U, object) - # TODO: no error - # error: [static-assert-error] static_assert(constraints.exists(T) == quantified) # (T ≤ B) ∧ (U ≤ Covariant[T]) → (U ≤ Covariant[B]) constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Never, U, Covariant[T]) quantified = ConstraintSet.range(Never, U, Covariant[B]) - # TODO: no error - # error: [static-assert-error] static_assert(constraints.exists(T) == quantified) def covariant_typevar_transitivity_reversed[T, B, U](): diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index dc78203e0319b..31c437b0a67d9 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -3181,10 +3181,20 @@ impl<'db> SequentMap<'db> { let constrained_lower = constrained_constraint.lower(db); let constrained_upper = constrained_constraint.upper(db); - let post_constraint = match constrained_lower.variance_of(db, bound_typevar) { + let mut add_constraint = |post_constraint| { + self.add_pair_implication( + db, + bound_constraint, + constrained_constraint, + post_constraint, + ); + self.enqueue_constraint(post_constraint); + }; + + match constrained_lower.variance_of(db, bound_typevar) { // B does not appear in CU, or if it does, it appears bivariantly. The constraints of B // do not affect the valid specializations of C. - TypeVarVariance::Bivariant => None, + TypeVarVariance::Bivariant => {} // (Covariant[B] ≤ C ≤ CU) ∧ (BL ≤ B ≤ BU) → (Covariant[BL] ≤ C ≤ CU) TypeVarVariance::Covariant => { @@ -3198,34 +3208,45 @@ impl<'db> SequentMap<'db> { &TypeMapping::PartialSpecialization(partial), TypeContext::default(), ); - Some(ConstrainedTypeVar::new( + add_constraint(ConstrainedTypeVar::new( db, constrained_typevar, new_lower, constrained_upper, - )) - } else { - None + )); } } // TODO - TypeVarVariance::Contravariant | TypeVarVariance::Invariant => None, + TypeVarVariance::Contravariant | TypeVarVariance::Invariant => {} }; - if let Some(post_constraint) = post_constraint { - self.add_pair_implication( - db, - bound_constraint, - constrained_constraint, - post_constraint, - ); - self.enqueue_constraint(post_constraint); + + // (Covariant[BU] ≤ C ≤ CU) ∧ (BL ≤ B ≤ BU) → (Covariant[B] ≤ C ≤ CU) + if let Type::TypeVar(bound_upper_typevar) = bound_upper { + if constrained_lower.variance_of(db, bound_upper_typevar) == TypeVarVariance::Covariant + { + let partial = PartialSpecialization::Single { + bound_typevar: bound_upper_typevar, + ty: Type::TypeVar(bound_typevar), + }; + let new_lower = constrained_lower.apply_type_mapping( + db, + &TypeMapping::PartialSpecialization(partial), + TypeContext::default(), + ); + add_constraint(ConstrainedTypeVar::new( + db, + constrained_typevar, + new_lower, + constrained_upper, + )); + } } - let post_constraint = match constrained_upper.variance_of(db, bound_typevar) { + match constrained_upper.variance_of(db, bound_typevar) { // B does not appear in CU, or if it does, it appears bivariantly. The constraints of B // do not affect the valid specializations of C. - TypeVarVariance::Bivariant => None, + TypeVarVariance::Bivariant => {} // (CL ≤ C ≤ Covariant[B]) ∧ (BL ≤ B ≤ BU) → (CL ≤ C ≤ Covariant[BU]) TypeVarVariance::Covariant => { @@ -3239,28 +3260,39 @@ impl<'db> SequentMap<'db> { &TypeMapping::PartialSpecialization(partial), TypeContext::default(), ); - Some(ConstrainedTypeVar::new( + add_constraint(ConstrainedTypeVar::new( db, constrained_typevar, constrained_lower, new_upper, - )) - } else { - None + )); } } // TODO - TypeVarVariance::Contravariant | TypeVarVariance::Invariant => None, - }; - if let Some(post_constraint) = post_constraint { - self.add_pair_implication( - db, - bound_constraint, - constrained_constraint, - post_constraint, - ); - self.enqueue_constraint(post_constraint); + TypeVarVariance::Contravariant | TypeVarVariance::Invariant => {} + } + + // (CL ≤ C ≤ Covariant[BL]) ∧ (BL ≤ B ≤ BU) → (CL ≤ C ≤ Covariant[B]) + if let Type::TypeVar(bound_lower_typevar) = bound_lower { + if constrained_upper.variance_of(db, bound_lower_typevar) == TypeVarVariance::Covariant + { + let partial = PartialSpecialization::Single { + bound_typevar: bound_lower_typevar, + ty: Type::TypeVar(bound_typevar), + }; + let new_upper = constrained_upper.apply_type_mapping( + db, + &TypeMapping::PartialSpecialization(partial), + TypeContext::default(), + ); + add_constraint(ConstrainedTypeVar::new( + db, + constrained_typevar, + constrained_lower, + new_upper, + )); + } } } From 3cc77a626db9adf213139f6cd1c0b7ad40a97b76 Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Wed, 17 Dec 2025 21:37:25 -0500 Subject: [PATCH 10/16] update higher-order callable tests --- .../mdtest/generics/pep695/functions.md | 81 ++++++++++++++++--- 1 file changed, 72 insertions(+), 9 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md index 3cdebe848e561..43b799f319002 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md @@ -481,22 +481,85 @@ reveal_type(g(f("a"))) # revealed: tuple[Literal["a"], int] | None ## Passing generic functions to generic functions -```py +`functions.pyi`: + +```pyi from typing import Callable -def invoke[A, B](fn: Callable[[A], B], value: A) -> B: - return fn(value) +def invoke[A, B](fn: Callable[[A], B], value: A) -> B: ... -def identity[T](x: T) -> T: - return x +def identity[T](x: T) -> T: ... + +class Covariant[T]: + def get(self) -> T: + raise NotImplementedError -def head[T](xs: list[T]) -> T: - return xs[0] +def head_covariant[T](xs: Covariant[T]) -> T: ... +def lift_covariant[T](xs: T) -> Covariant[T]: ... + +class Contravariant[T]: + def receive(self, input: T): ... + +def head_contravariant[T](xs: Contravariant[T]) -> T: ... +def lift_contravariant[T](xs: T) -> Contravariant[T]: ... + +class Invariant[T]: + mutable_attribute: T + +def head_invariant[T](xs: Invariant[T]) -> T: ... +def lift_invariant[T](xs: T) -> Invariant[T]: ... +``` + +A simple function that passes through its parameter type unchanged: + +`simple.py`: + +```py +from functions import invoke, identity reveal_type(invoke(identity, 1)) # revealed: Literal[1] +``` + +When the either the parameter or the return type is a generic alias referring to the typevar, we +should still be able to propagate the specializations through. This should work regardless of the +typevar's variance in the generic alias. + +`covariant.py`: -# TODO: this should be `Unknown | int` -reveal_type(invoke(head, [1, 2, 3])) # revealed: Unknown +```py +from functions import invoke, Covariant, head_covariant, lift_covariant + +# TODO: revealed: `int` +# revealed: Unknown +reveal_type(invoke(head_covariant, Covariant[int]())) +# revealed: Covariant[Literal[1]] +reveal_type(invoke(lift_covariant, 1)) +``` + +`contravariant.py`: + +```py +from functions import invoke, Contravariant, head_contravariant, lift_contravariant + +# TODO: revealed: `int` +# revealed: Unknown +reveal_type(invoke(head_contravariant, Contravariant[int]())) +# TODO: revealed: Contravariant[int] +# revealed: Unknown +reveal_type(invoke(lift_contravariant, 1)) +``` + +`invariant.py`: + +```py +from functions import invoke, Invariant, head_invariant, lift_invariant + +# TODO: revealed: `int` +# revealed: Unknown +reveal_type(invoke(head_invariant, Invariant[int]())) +# TODO: revealed: `Invariant[int]` +# revealed: Unknown +reveal_type(invoke(lift_invariant, 1)) ``` ## Protocols as TypeVar bounds From a694ede51d1203a99ae05fdd9afe5b12dde04811 Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Thu, 18 Dec 2025 11:11:43 -0500 Subject: [PATCH 11/16] clippy --- crates/ty_python_semantic/src/types/constraints.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index 31c437b0a67d9..8a8a9ee664d87 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -3219,7 +3219,7 @@ impl<'db> SequentMap<'db> { // TODO TypeVarVariance::Contravariant | TypeVarVariance::Invariant => {} - }; + } // (Covariant[BU] ≤ C ≤ CU) ∧ (BL ≤ B ≤ BU) → (Covariant[B] ≤ C ≤ CU) if let Type::TypeVar(bound_upper_typevar) = bound_upper { From a35c733b64af0a3ed1a73edcb97d3b64535f5292 Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Thu, 18 Dec 2025 11:16:44 -0500 Subject: [PATCH 12/16] clean up callback types --- .../ty_python_semantic/src/types/generics.rs | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index ad7529eb2244d..b8d8283667e81 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -1560,7 +1560,7 @@ impl<'db> SpecializationBuilder<'db> { bound_typevar: BoundTypeVarInstance<'db>, ty: Type<'db>, variance: TypeVarVariance, - mut f: impl FnMut(TypeVarAssignment<'db>) -> Option>, + f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option>, ) { let identity = bound_typevar.identity(self.db); let Some(ty) = f((identity, variance, ty)) else { @@ -1602,7 +1602,7 @@ impl<'db> SpecializationBuilder<'db> { &mut self, formal: Type<'db>, constraints: ConstraintSet<'db>, - mut f: impl FnMut(TypeVarAssignment<'db>) -> Option>, + f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option>, ) { #[derive(Default)] struct Bounds<'db> { @@ -1649,12 +1649,12 @@ impl<'db> SpecializationBuilder<'db> { let variance = formal.variance_of(self.db, bound_typevar); let upper = IntersectionType::from_elements(self.db, bounds.upper); if !upper.is_object() { - self.add_type_mapping(bound_typevar, upper, variance, &mut f); + self.add_type_mapping(bound_typevar, upper, variance, f); continue; } let lower = UnionType::from_elements(self.db, bounds.lower); if !lower.is_never() { - self.add_type_mapping(bound_typevar, lower, variance, &mut f); + self.add_type_mapping(bound_typevar, lower, variance, f); } } } @@ -1687,7 +1687,7 @@ impl<'db> SpecializationBuilder<'db> { formal: Type<'db>, actual: Type<'db>, polarity: TypeVarVariance, - mut f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option>, + f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option>, ) -> Result<(), SpecializationError<'db>> { // TODO: Eventually, the builder will maintain a constraint set, instead of a hash-map of // type mappings, to represent the specialization that we are building up. At that point, @@ -1799,7 +1799,7 @@ impl<'db> SpecializationBuilder<'db> { let mut first_error = None; let mut found_matching_element = false; for formal_element in union_formal.elements(self.db) { - let result = self.infer_map_impl(*formal_element, actual, polarity, &mut f); + let result = self.infer_map_impl(*formal_element, actual, polarity, f); if let Err(err) = result { first_error.get_or_insert(err); } else { @@ -1902,7 +1902,7 @@ impl<'db> SpecializationBuilder<'db> { formal_tuple.all_elements().zip(actual_tuple.all_elements()) { let variance = TypeVarVariance::Covariant.compose(polarity); - self.infer_map_impl(*formal_element, *actual_element, variance, &mut f)?; + self.infer_map_impl(*formal_element, *actual_element, variance, f)?; } return Ok(()); } @@ -1945,7 +1945,7 @@ impl<'db> SpecializationBuilder<'db> { base_specialization ) { let variance = typevar.variance_with_polarity(self.db, polarity); - self.infer_map_impl(*formal_ty, *base_ty, variance, &mut f)?; + self.infer_map_impl(*formal_ty, *base_ty, variance, f)?; } return Ok(()); } @@ -1969,7 +1969,7 @@ impl<'db> SpecializationBuilder<'db> { formal_callable, self.inferable, ); - self.add_type_mappings_from_constraint_set(formal, when, &mut f); + self.add_type_mappings_from_constraint_set(formal, when, f); } else { for actual_signature in &actual_callable.signatures(self.db).overloads { let when = actual_signature @@ -1978,7 +1978,7 @@ impl<'db> SpecializationBuilder<'db> { formal_callable, self.inferable, ); - self.add_type_mappings_from_constraint_set(formal, when, &mut f); + self.add_type_mappings_from_constraint_set(formal, when, f); } } } From 3d0dea21adb10d5371b71974242d51f0acf1d922 Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Thu, 18 Dec 2025 11:17:04 -0500 Subject: [PATCH 13/16] use or_else here, it's cleaner --- .../ty_python_semantic/src/types/generics.rs | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index b8d8283667e81..3b728ee4578c0 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -1646,16 +1646,20 @@ impl<'db> SpecializationBuilder<'db> { } for (bound_typevar, bounds) in mappings.drain() { - let variance = formal.variance_of(self.db, bound_typevar); - let upper = IntersectionType::from_elements(self.db, bounds.upper); - if !upper.is_object() { - self.add_type_mapping(bound_typevar, upper, variance, f); + let try_upper = || { + let upper = IntersectionType::from_elements(self.db, bounds.upper); + (!upper.is_object()).then_some(upper) + }; + let try_lower = || { + let lower = UnionType::from_elements(self.db, bounds.lower); + (!lower.is_never()).then_some(lower) + }; + let Some(mapped_type) = try_upper().or_else(try_lower) else { continue; - } - let lower = UnionType::from_elements(self.db, bounds.lower); - if !lower.is_never() { - self.add_type_mapping(bound_typevar, lower, variance, f); - } + }; + + let variance = formal.variance_of(self.db, bound_typevar); + self.add_type_mapping(bound_typevar, mapped_type, variance, f); } } } From 15c9d66ef6d3fe773a63d174848e53337a357fb5 Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Thu, 18 Dec 2025 20:56:20 -0500 Subject: [PATCH 14/16] document missing case --- .../resources/mdtest/generics/pep695/functions.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md index 43b799f319002..9c67fd6cba18a 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md @@ -524,6 +524,20 @@ When the either the parameter or the return type is a generic alias referring to should still be able to propagate the specializations through. This should work regardless of the typevar's variance in the generic alias. +TODO: This currently only works for the `lift` functions (TODO: and only currently for the covariant +case). For the `lift` functions, the parameter type is a bare typevar, resulting in us inferring a +type mapping of `A = int, B = Class[A]`. When specializing, we can substitute the mapping of `A` +into the mapping of `B`, giving the correct return type. + +For the `head` functions, the parameter type is a generic alias, resulting in us inferring a type +mapping of `A = Class[int], A = Class[B]`. At this point, the old solver is not able to unify the +two mappings for `A`, and we have no mapping for `B`. As a result, we infer `Unknown` for the return +type. + +As part of migrating to the new solver, we will generate a single constraint set combining all of +the facts that we learn while checking the arguments. And the constraint set implementation should +be able to unify the two assignments to `A`. + `covariant.py`: ```py From d6a98d8f2ce6d4f6fcae1613b31607c9b201656c Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Thu, 18 Dec 2025 20:59:34 -0500 Subject: [PATCH 15/16] mdlint --- .../resources/mdtest/type_properties/quantification.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md b/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md index 70b6b6f3d3a8c..d3877275cd032 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/quantification.md @@ -37,9 +37,9 @@ def keep_single[T](): ## Removing a single typevar -If a constraint set only mentions a single typevar, and we remove that typevar when quantifying, -the result is usually "always". The only exception is if the original constraint set has no -solution. In that case, the result is also unsatisfiable. +If a constraint set only mentions a single typevar, and we remove that typevar when quantifying, the +result is usually "always". The only exception is if the original constraint set has no solution. In +that case, the result is also unsatisfiable. ```py from ty_extensions import ConstraintSet, static_assert From e1f9ba7f05bcd254433a93273061930f5c66cb46 Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Fri, 19 Dec 2025 13:35:33 -0500 Subject: [PATCH 16/16] don't substitute recursively! --- .../src/types/constraints.rs | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index 8a8a9ee664d87..b7474145257c4 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -3198,7 +3198,12 @@ impl<'db> SequentMap<'db> { // (Covariant[B] ≤ C ≤ CU) ∧ (BL ≤ B ≤ BU) → (Covariant[BL] ≤ C ≤ CU) TypeVarVariance::Covariant => { - if !bound_lower.is_never() && !bound_lower.is_object() { + // Only substitute to create a new sequent if the substitution is interesting, and + // doesn't recursively contain the typevar we are substituting for. + if !bound_lower.is_never() + && !bound_lower.is_object() + && bound_lower.variance_of(db, bound_typevar) == TypeVarVariance::Bivariant + { let partial = PartialSpecialization::Single { bound_typevar, ty: bound_lower, @@ -3222,7 +3227,9 @@ impl<'db> SequentMap<'db> { } // (Covariant[BU] ≤ C ≤ CU) ∧ (BL ≤ B ≤ BU) → (Covariant[B] ≤ C ≤ CU) - if let Type::TypeVar(bound_upper_typevar) = bound_upper { + if let Type::TypeVar(bound_upper_typevar) = bound_upper + && !bound_upper_typevar.is_same_typevar_as(db, constrained_typevar) + { if constrained_lower.variance_of(db, bound_upper_typevar) == TypeVarVariance::Covariant { let partial = PartialSpecialization::Single { @@ -3250,7 +3257,12 @@ impl<'db> SequentMap<'db> { // (CL ≤ C ≤ Covariant[B]) ∧ (BL ≤ B ≤ BU) → (CL ≤ C ≤ Covariant[BU]) TypeVarVariance::Covariant => { - if !bound_upper.is_never() && !bound_upper.is_object() { + // Only substitute to create a new sequent if the substitution is interesting, and + // doesn't recursively contain the typevar we are substituting for. + if !bound_upper.is_never() + && !bound_upper.is_object() + && bound_upper.variance_of(db, bound_typevar) == TypeVarVariance::Bivariant + { let partial = PartialSpecialization::Single { bound_typevar, ty: bound_upper, @@ -3274,7 +3286,9 @@ impl<'db> SequentMap<'db> { } // (CL ≤ C ≤ Covariant[BL]) ∧ (BL ≤ B ≤ BU) → (CL ≤ C ≤ Covariant[B]) - if let Type::TypeVar(bound_lower_typevar) = bound_lower { + if let Type::TypeVar(bound_lower_typevar) = bound_lower + && !bound_lower_typevar.is_same_typevar_as(db, constrained_typevar) + { if constrained_upper.variance_of(db, bound_lower_typevar) == TypeVarVariance::Covariant { let partial = PartialSpecialization::Single {