diff --git a/crates/ty_python_semantic/src/types/relation.rs b/crates/ty_python_semantic/src/types/relation.rs index 34fbde3bb0856..ca6e930c71323 100644 --- a/crates/ty_python_semantic/src/types/relation.rs +++ b/crates/ty_python_semantic/src/types/relation.rs @@ -578,6 +578,16 @@ impl<'a, 'c, 'db> TypeRelationChecker<'a, 'c, 'db> { ConstraintSet::from_bool(self.constraints, false) } + fn with_recursion_guard( + &self, + source: Type<'db>, + target: Type<'db>, + work: impl FnOnce() -> ConstraintSet<'db, 'c>, + ) -> ConstraintSet<'db, 'c> { + self.relation_visitor + .visit((source, target, self.relation), work) + } + pub(super) fn check_type_pair( &self, db: &'db dyn Db, @@ -654,17 +664,13 @@ impl<'a, 'c, 'db> TypeRelationChecker<'a, 'c, 'db> { ConstraintSet::from_bool(self.constraints, self.relation.is_assignability()) } - (Type::TypeAlias(source_alias), _) => self - .relation_visitor - .visit((source, target, self.relation), || { - self.check_type_pair(db, source_alias.value_type(db), target) - }), + (Type::TypeAlias(source_alias), _) => self.with_recursion_guard(source, target, || { + self.check_type_pair(db, source_alias.value_type(db), target) + }), - (_, Type::TypeAlias(target_alias)) => self - .relation_visitor - .visit((source, target, self.relation), || { - self.check_type_pair(db, source, target_alias.value_type(db)) - }), + (_, Type::TypeAlias(target_alias)) => self.with_recursion_guard(source, target, || { + self.check_type_pair(db, source, target_alias.value_type(db)) + }), // Pretend that instances of `dataclasses.Field` are assignable to their default type. // This allows field definitions like `name: str = field(default="")` in dataclasses @@ -1068,11 +1074,11 @@ impl<'a, 'c, 'db> TypeRelationChecker<'a, 'c, 'db> { ConstraintSet::from_bool(self.constraints, source.bool(db).is_always_true()) } // Currently, the only supertype of `AlwaysFalsy` and `AlwaysTruthy` is the universal set (object instance). - (Type::AlwaysFalsy | Type::AlwaysTruthy, _) => self - .relation_visitor - .visit((source, target, self.relation), || { + (Type::AlwaysFalsy | Type::AlwaysTruthy, _) => { + self.with_recursion_guard(source, target, || { self.check_type_pair(db, Type::object(), target) - }), + }) + } // These clauses handle type variants that include function literals. A function // literal is the subtype of itself, and not of any other function literal. However, @@ -1127,23 +1133,18 @@ impl<'a, 'c, 'db> TypeRelationChecker<'a, 'c, 'db> { ) => self.never(), (Type::Callable(source_callable), Type::Callable(target_callable)) => self - .relation_visitor - .visit((source, target, self.relation), || { + .with_recursion_guard(source, target, || { self.check_callable_pair(db, source_callable, target_callable) }), (_, Type::Callable(target_callable)) => { - self.relation_visitor - .visit((source, target, self.relation), || { - source - .try_upcast_to_callable_with_policy( - db, - UpcastPolicy::from(self.relation), - ) - .when_some_and(db, self.constraints, |callables| { - self.check_callables_vs_callable(db, &callables, target_callable) - }) - }) + self.with_recursion_guard(source, target, || { + source + .try_upcast_to_callable_with_policy(db, UpcastPolicy::from(self.relation)) + .when_some_and(db, self.constraints, |callables| { + self.check_callables_vs_callable(db, &callables, target_callable) + }) + }) } // `type[Any]` is assignable to arbitrary protocols as it has arbitrary attributes @@ -1159,33 +1160,30 @@ impl<'a, 'c, 'db> TypeRelationChecker<'a, 'c, 'db> { self.check_type_pair(db, KnownClass::Type.to_instance(db), target) } - (_, Type::ProtocolInstance(target_proto)) => self - .relation_visitor - .visit((source, target, self.relation), || { + (_, Type::ProtocolInstance(target_proto)) => { + self.with_recursion_guard(source, target, || { self.check_type_satisfies_protocol(db, source, target_proto) - }), + }) + } // A protocol instance can never be a subtype of a nominal type, with the *sole* exception of `object`. (Type::ProtocolInstance(_), _) => self.never(), - (Type::TypedDict(source_td), Type::TypedDict(target_td)) => self - .relation_visitor - .visit((source, target, self.relation), || { + (Type::TypedDict(source_td), Type::TypedDict(target_td)) => { + self.with_recursion_guard(source, target, || { self.check_typeddict_pair(db, source_td, target_td) - }), + }) + } // TODO: When we support `closed` and/or `extra_items`, we could allow assignments to other // compatible `Mapping`s. `extra_items` could also allow for some assignments to `dict`, as // long as `total=False`. (But then again, does anyone want a non-total `TypedDict` where all // key types are a supertype of the extra items type?) - (Type::TypedDict(_), _) => { - self.relation_visitor - .visit((source, target, self.relation), || { - let spec = &[KnownClass::Str.to_instance(db), Type::object()]; - let str_object_map = KnownClass::Mapping.to_specialized_instance(db, spec); - self.check_type_pair(db, str_object_map, target) - }) - } + (Type::TypedDict(_), _) => self.with_recursion_guard(source, target, || { + let spec = &[KnownClass::Str.to_instance(db), Type::object()]; + let str_object_map = KnownClass::Mapping.to_specialized_instance(db, spec); + self.check_type_pair(db, str_object_map, target) + }), // A non-`TypedDict` cannot subtype a `TypedDict` (_, Type::TypedDict(_)) => self.never(), @@ -1480,14 +1478,12 @@ impl<'a, 'c, 'db> TypeRelationChecker<'a, 'c, 'db> { // `bool` is a subtype of `int`, because `bool` subclasses `int`, // which means that all instances of `bool` are also instances of `int` (Type::NominalInstance(source_i), Type::NominalInstance(target_i)) => self - .relation_visitor - .visit((source, target, self.relation), || { + .with_recursion_guard(source, target, || { self.check_nominal_instance_pair(db, source_i, target_i) }), (Type::PropertyInstance(source_p), Type::PropertyInstance(target_p)) => self - .relation_visitor - .visit((source, target, self.relation), || { + .with_recursion_guard(source, target, || { self.check_property_instance_pair(db, source_p, target_p) }), @@ -1643,6 +1639,15 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> { } } + fn with_recursion_guard( + &self, + source: Type<'db>, + target: Type<'db>, + work: impl FnOnce() -> ConstraintSet<'db, 'c>, + ) -> ConstraintSet<'db, 'c> { + self.disjointness_visitor.visit((source, target), work) + } + fn any_protocol_members_absent_or_disjoint( &self, db: &'db dyn Db, @@ -1684,14 +1689,14 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> { (Type::TypeAlias(alias), _) => { let left_alias_ty = alias.value_type(db); - self.disjointness_visitor.visit((left, right), || { + self.with_recursion_guard(left, right, || { self.check_type_pair(db, left_alias_ty, right) }) } (_, Type::TypeAlias(alias)) => { let right_alias_ty = alias.value_type(db); - self.disjointness_visitor.visit((left, right), || { + self.with_recursion_guard(left, right, || { self.check_type_pair(db, left, right_alias_ty) }) } @@ -1787,8 +1792,8 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> { // If we have two intersections, we test the positive elements of each one against the other intersection // Negative elements need a positive element on the other side in order to be disjoint. // This is similar to what would happen if we tried to build a new intersection that combines the two - (Type::Intersection(left_intersection), Type::Intersection(right_intersection)) => { - self.disjointness_visitor.visit((left, right), || { + (Type::Intersection(left_intersection), Type::Intersection(right_intersection)) => self + .with_recursion_guard(left, right, || { left_intersection .positive(db) .iter() @@ -1802,12 +1807,11 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> { |&pos_ty| self.check_type_pair(db, pos_ty, left), ) }) - }) - } + }), (Type::Intersection(intersection), other) | (other, Type::Intersection(intersection)) => { - self.disjointness_visitor.visit((left, right), || { + self.with_recursion_guard(left, right, || { intersection .positive(db) .iter() @@ -1904,33 +1908,30 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> { ConstraintSet::from_bool(self.constraints, ty.bool(db).is_always_true()) } - (Type::ProtocolInstance(left_proto), Type::ProtocolInstance(right_proto)) => { - self.disjointness_visitor.visit((left, right), || { + (Type::ProtocolInstance(left_proto), Type::ProtocolInstance(right_proto)) => self + .with_recursion_guard(left, right, || { self.check_protocol_instance_pair(db, left_proto, right_proto) - }) - } + }), (Type::ProtocolInstance(protocol), Type::SpecialForm(special_form)) - | (Type::SpecialForm(special_form), Type::ProtocolInstance(protocol)) => { - self.disjointness_visitor.visit((left, right), || { + | (Type::SpecialForm(special_form), Type::ProtocolInstance(protocol)) => self + .with_recursion_guard(left, right, || { self.any_protocol_members_absent_or_disjoint( db, protocol, special_form.instance_fallback(db), ) - }) - } + }), (Type::ProtocolInstance(protocol), Type::KnownInstance(known_instance)) - | (Type::KnownInstance(known_instance), Type::ProtocolInstance(protocol)) => { - self.disjointness_visitor.visit((left, right), || { + | (Type::KnownInstance(known_instance), Type::ProtocolInstance(protocol)) => self + .with_recursion_guard(left, right, || { self.any_protocol_members_absent_or_disjoint( db, protocol, known_instance.instance_fallback(db), ) - }) - } + }), // The absence of a protocol member on one of these types guarantees // that the type will be disjoint from the protocol, @@ -1974,7 +1975,7 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> { | Type::FunctionLiteral(..) | Type::ModuleLiteral(..) | Type::GenericAlias(..)), - ) => self.disjointness_visitor.visit((left, right), || { + ) => self.with_recursion_guard(left, right, || { self.any_protocol_members_absent_or_disjoint(db, protocol, ty) }), @@ -1985,7 +1986,7 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> { | (Type::NominalInstance(nominal), Type::ProtocolInstance(protocol)) if nominal.class(db).is_final(db) => { - self.disjointness_visitor.visit((left, right), || { + self.with_recursion_guard(left, right, || { self.any_protocol_members_absent_or_disjoint( db, protocol, @@ -1996,7 +1997,7 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> { (Type::ProtocolInstance(protocol), other) | (other, Type::ProtocolInstance(protocol)) => { - self.disjointness_visitor.visit((left, right), || { + self.with_recursion_guard(left, right, || { protocol .interface(db) .members(db) @@ -2253,11 +2254,10 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> { ) } - (Type::NominalInstance(left_i), Type::NominalInstance(right_i)) => { - self.disjointness_visitor.visit((left, right), || { + (Type::NominalInstance(left_i), Type::NominalInstance(right_i)) => self + .with_recursion_guard(left, right, || { self.check_nominal_instance_pair(db, left_i, right_i) - }) - } + }), (Type::NewTypeInstance(left), Type::NewTypeInstance(right)) => { self.check_newtype_pair(db, left, right) @@ -2282,7 +2282,7 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> { (Type::GenericAlias(_), _) | (_, Type::GenericAlias(_)) => self.always(), (Type::TypedDict(left_td), Type::TypedDict(right_td)) => { - self.disjointness_visitor.visit((left, right), || { + self.with_recursion_guard(left, right, || { self.check_typeddict_pair(db, left_td, right_td) }) }