Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 74 additions & 74 deletions crates/ty_python_semantic/src/types/relation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,16 @@ impl<'a, 'c, 'db> TypeRelationChecker<'a, 'c, 'db> {
ConstraintSet::from_bool(self.constraints, false)
}

fn with_recursion_guard(
&self,
source: Type<'db>,
target: Type<'db>,
work: impl FnOnce() -> ConstraintSet<'db, 'c>,
) -> ConstraintSet<'db, 'c> {
self.relation_visitor
.visit((source, target, self.relation), work)
}

pub(super) fn check_type_pair(
&self,
db: &'db dyn Db,
Expand Down Expand Up @@ -654,17 +664,13 @@ impl<'a, 'c, 'db> TypeRelationChecker<'a, 'c, 'db> {
ConstraintSet::from_bool(self.constraints, self.relation.is_assignability())
}

(Type::TypeAlias(source_alias), _) => self
.relation_visitor
.visit((source, target, self.relation), || {
self.check_type_pair(db, source_alias.value_type(db), target)
}),
(Type::TypeAlias(source_alias), _) => self.with_recursion_guard(source, target, || {
self.check_type_pair(db, source_alias.value_type(db), target)
}),

(_, Type::TypeAlias(target_alias)) => self
.relation_visitor
.visit((source, target, self.relation), || {
self.check_type_pair(db, source, target_alias.value_type(db))
}),
(_, Type::TypeAlias(target_alias)) => self.with_recursion_guard(source, target, || {
self.check_type_pair(db, source, target_alias.value_type(db))
}),

// Pretend that instances of `dataclasses.Field` are assignable to their default type.
// This allows field definitions like `name: str = field(default="")` in dataclasses
Expand Down Expand Up @@ -1068,11 +1074,11 @@ impl<'a, 'c, 'db> TypeRelationChecker<'a, 'c, 'db> {
ConstraintSet::from_bool(self.constraints, source.bool(db).is_always_true())
}
// Currently, the only supertype of `AlwaysFalsy` and `AlwaysTruthy` is the universal set (object instance).
(Type::AlwaysFalsy | Type::AlwaysTruthy, _) => self
.relation_visitor
.visit((source, target, self.relation), || {
(Type::AlwaysFalsy | Type::AlwaysTruthy, _) => {
self.with_recursion_guard(source, target, || {
self.check_type_pair(db, Type::object(), target)
}),
})
}

// These clauses handle type variants that include function literals. A function
// literal is the subtype of itself, and not of any other function literal. However,
Expand Down Expand Up @@ -1127,23 +1133,18 @@ impl<'a, 'c, 'db> TypeRelationChecker<'a, 'c, 'db> {
) => self.never(),

(Type::Callable(source_callable), Type::Callable(target_callable)) => self
.relation_visitor
.visit((source, target, self.relation), || {
.with_recursion_guard(source, target, || {
self.check_callable_pair(db, source_callable, target_callable)
}),

(_, Type::Callable(target_callable)) => {
self.relation_visitor
.visit((source, target, self.relation), || {
source
.try_upcast_to_callable_with_policy(
db,
UpcastPolicy::from(self.relation),
)
.when_some_and(db, self.constraints, |callables| {
self.check_callables_vs_callable(db, &callables, target_callable)
})
})
self.with_recursion_guard(source, target, || {
source
.try_upcast_to_callable_with_policy(db, UpcastPolicy::from(self.relation))
.when_some_and(db, self.constraints, |callables| {
self.check_callables_vs_callable(db, &callables, target_callable)
})
})
}

// `type[Any]` is assignable to arbitrary protocols as it has arbitrary attributes
Expand All @@ -1159,33 +1160,30 @@ impl<'a, 'c, 'db> TypeRelationChecker<'a, 'c, 'db> {
self.check_type_pair(db, KnownClass::Type.to_instance(db), target)
}

(_, Type::ProtocolInstance(target_proto)) => self
.relation_visitor
.visit((source, target, self.relation), || {
(_, Type::ProtocolInstance(target_proto)) => {
self.with_recursion_guard(source, target, || {
self.check_type_satisfies_protocol(db, source, target_proto)
}),
})
}

// A protocol instance can never be a subtype of a nominal type, with the *sole* exception of `object`.
(Type::ProtocolInstance(_), _) => self.never(),

(Type::TypedDict(source_td), Type::TypedDict(target_td)) => self
.relation_visitor
.visit((source, target, self.relation), || {
(Type::TypedDict(source_td), Type::TypedDict(target_td)) => {
self.with_recursion_guard(source, target, || {
self.check_typeddict_pair(db, source_td, target_td)
}),
})
}

// TODO: When we support `closed` and/or `extra_items`, we could allow assignments to other
// compatible `Mapping`s. `extra_items` could also allow for some assignments to `dict`, as
// long as `total=False`. (But then again, does anyone want a non-total `TypedDict` where all
// key types are a supertype of the extra items type?)
(Type::TypedDict(_), _) => {
self.relation_visitor
.visit((source, target, self.relation), || {
let spec = &[KnownClass::Str.to_instance(db), Type::object()];
let str_object_map = KnownClass::Mapping.to_specialized_instance(db, spec);
self.check_type_pair(db, str_object_map, target)
})
}
(Type::TypedDict(_), _) => self.with_recursion_guard(source, target, || {
let spec = &[KnownClass::Str.to_instance(db), Type::object()];
let str_object_map = KnownClass::Mapping.to_specialized_instance(db, spec);
self.check_type_pair(db, str_object_map, target)
}),

// A non-`TypedDict` cannot subtype a `TypedDict`
(_, Type::TypedDict(_)) => self.never(),
Expand Down Expand Up @@ -1480,14 +1478,12 @@ impl<'a, 'c, 'db> TypeRelationChecker<'a, 'c, 'db> {
// `bool` is a subtype of `int`, because `bool` subclasses `int`,
// which means that all instances of `bool` are also instances of `int`
(Type::NominalInstance(source_i), Type::NominalInstance(target_i)) => self
.relation_visitor
.visit((source, target, self.relation), || {
.with_recursion_guard(source, target, || {
self.check_nominal_instance_pair(db, source_i, target_i)
}),

(Type::PropertyInstance(source_p), Type::PropertyInstance(target_p)) => self
.relation_visitor
.visit((source, target, self.relation), || {
.with_recursion_guard(source, target, || {
self.check_property_instance_pair(db, source_p, target_p)
}),

Expand Down Expand Up @@ -1643,6 +1639,15 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> {
}
}

fn with_recursion_guard(
&self,
source: Type<'db>,
target: Type<'db>,
work: impl FnOnce() -> ConstraintSet<'db, 'c>,
) -> ConstraintSet<'db, 'c> {
self.disjointness_visitor.visit((source, target), work)
}

fn any_protocol_members_absent_or_disjoint(
&self,
db: &'db dyn Db,
Expand Down Expand Up @@ -1684,14 +1689,14 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> {

(Type::TypeAlias(alias), _) => {
let left_alias_ty = alias.value_type(db);
self.disjointness_visitor.visit((left, right), || {
self.with_recursion_guard(left, right, || {
self.check_type_pair(db, left_alias_ty, right)
})
}

(_, Type::TypeAlias(alias)) => {
let right_alias_ty = alias.value_type(db);
self.disjointness_visitor.visit((left, right), || {
self.with_recursion_guard(left, right, || {
self.check_type_pair(db, left, right_alias_ty)
})
}
Expand Down Expand Up @@ -1787,8 +1792,8 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> {
// If we have two intersections, we test the positive elements of each one against the other intersection
// Negative elements need a positive element on the other side in order to be disjoint.
// This is similar to what would happen if we tried to build a new intersection that combines the two
(Type::Intersection(left_intersection), Type::Intersection(right_intersection)) => {
self.disjointness_visitor.visit((left, right), || {
(Type::Intersection(left_intersection), Type::Intersection(right_intersection)) => self
.with_recursion_guard(left, right, || {
left_intersection
.positive(db)
.iter()
Expand All @@ -1802,12 +1807,11 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> {
|&pos_ty| self.check_type_pair(db, pos_ty, left),
)
})
})
}
}),

(Type::Intersection(intersection), other)
| (other, Type::Intersection(intersection)) => {
self.disjointness_visitor.visit((left, right), || {
self.with_recursion_guard(left, right, || {
intersection
.positive(db)
.iter()
Expand Down Expand Up @@ -1904,33 +1908,30 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> {
ConstraintSet::from_bool(self.constraints, ty.bool(db).is_always_true())
}

(Type::ProtocolInstance(left_proto), Type::ProtocolInstance(right_proto)) => {
self.disjointness_visitor.visit((left, right), || {
(Type::ProtocolInstance(left_proto), Type::ProtocolInstance(right_proto)) => self
.with_recursion_guard(left, right, || {
self.check_protocol_instance_pair(db, left_proto, right_proto)
})
}
}),

(Type::ProtocolInstance(protocol), Type::SpecialForm(special_form))
| (Type::SpecialForm(special_form), Type::ProtocolInstance(protocol)) => {
self.disjointness_visitor.visit((left, right), || {
| (Type::SpecialForm(special_form), Type::ProtocolInstance(protocol)) => self
.with_recursion_guard(left, right, || {
self.any_protocol_members_absent_or_disjoint(
db,
protocol,
special_form.instance_fallback(db),
)
})
}
}),

(Type::ProtocolInstance(protocol), Type::KnownInstance(known_instance))
| (Type::KnownInstance(known_instance), Type::ProtocolInstance(protocol)) => {
self.disjointness_visitor.visit((left, right), || {
| (Type::KnownInstance(known_instance), Type::ProtocolInstance(protocol)) => self
.with_recursion_guard(left, right, || {
self.any_protocol_members_absent_or_disjoint(
db,
protocol,
known_instance.instance_fallback(db),
)
})
}
}),

// The absence of a protocol member on one of these types guarantees
// that the type will be disjoint from the protocol,
Expand Down Expand Up @@ -1974,7 +1975,7 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> {
| Type::FunctionLiteral(..)
| Type::ModuleLiteral(..)
| Type::GenericAlias(..)),
) => self.disjointness_visitor.visit((left, right), || {
) => self.with_recursion_guard(left, right, || {
self.any_protocol_members_absent_or_disjoint(db, protocol, ty)
}),

Expand All @@ -1985,7 +1986,7 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> {
| (Type::NominalInstance(nominal), Type::ProtocolInstance(protocol))
if nominal.class(db).is_final(db) =>
{
self.disjointness_visitor.visit((left, right), || {
self.with_recursion_guard(left, right, || {
self.any_protocol_members_absent_or_disjoint(
db,
protocol,
Expand All @@ -1996,7 +1997,7 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> {

(Type::ProtocolInstance(protocol), other)
| (other, Type::ProtocolInstance(protocol)) => {
self.disjointness_visitor.visit((left, right), || {
self.with_recursion_guard(left, right, || {
protocol
.interface(db)
.members(db)
Expand Down Expand Up @@ -2253,11 +2254,10 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> {
)
}

(Type::NominalInstance(left_i), Type::NominalInstance(right_i)) => {
self.disjointness_visitor.visit((left, right), || {
(Type::NominalInstance(left_i), Type::NominalInstance(right_i)) => self
.with_recursion_guard(left, right, || {
self.check_nominal_instance_pair(db, left_i, right_i)
})
}
}),

(Type::NewTypeInstance(left), Type::NewTypeInstance(right)) => {
self.check_newtype_pair(db, left, right)
Expand All @@ -2282,7 +2282,7 @@ impl<'a, 'c, 'db> DisjointnessChecker<'a, 'c, 'db> {
(Type::GenericAlias(_), _) | (_, Type::GenericAlias(_)) => self.always(),

(Type::TypedDict(left_td), Type::TypedDict(right_td)) => {
self.disjointness_visitor.visit((left, right), || {
self.with_recursion_guard(left, right, || {
self.check_typeddict_pair(db, left_td, right_td)
})
}
Expand Down
Loading