Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions crates/ty_python_semantic/resources/mdtest/protocols.md
Original file line number Diff line number Diff line change
Expand Up @@ -3010,6 +3010,31 @@ class Bar(Protocol[S]):
z: S | Bar[S]
```

### Recursive generic protocols with growing specializations

This snippet caused a stack overflow in <https://github.com/astral-sh/ty/issues/1736> because the
type parameter grows with each recursive call (`C[set[T]]` leads to `C[set[set[T]]]`, then
`C[set[set[set[T]]]]`, etc.):

```toml
[environment]
python-version = "3.12"
```

```py
from typing import Protocol

class C[T](Protocol):
a: "C[set[T]]"

def takes_c(c: C[set[int]]) -> None: ...
def f(c: C[int]) -> None:
# The key thing is that we don't stack overflow while checking this.
# The cycle detection assumes compatibility when it detects potential
# infinite recursion between protocol specializations.
takes_c(c)
```

### Recursive legacy generic protocol

```py
Expand Down
107 changes: 85 additions & 22 deletions crates/ty_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,48 @@ fn definition_expression_type<'db>(
/// A [`TypeTransformer`] that is used in `apply_type_mapping` methods.
pub(crate) type ApplyTypeMappingVisitor<'db> = TypeTransformer<'db, TypeMapping<'db, 'db>>;

/// A [`PairVisitor`] that is used in `has_relation_to` methods.
pub(crate) type HasRelationToVisitor<'db> =
CycleDetector<TypeRelation<'db>, (Type<'db>, Type<'db>, TypeRelation<'db>), ConstraintSet<'db>>;
/// Key type for the `has_relation_to` visitor.
///
/// For most type comparisons, we use the full `Type` as the key. However, for protocol-to-protocol
/// comparisons, we use the underlying `ClassLiteral` (ignoring specialization) to detect infinite
/// recursion that occurs with recursive generic protocols.
///
/// For example, with:
/// ```python
/// class C[T](Protocol):
/// a: 'C[set[T]]'
/// ```
///
/// Checking `C[set[int]] <: C[set[int]]` leads to checking `C[set[set[int]]] <: C[set[set[int]]]`,
/// then `C[set[set[set[int]]]] <: C[set[set[set[int]]]]`, etc. Each level has different type
/// specializations, so using full types as keys doesn't detect the cycle. By using `ClassLiteral`
/// as the key for protocol comparisons, we detect that we're comparing protocol `C` against itself
/// regardless of specialization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum TypeRelationKey<'db> {
/// A regular type - used for most comparisons.
Type(Type<'db>),
/// A protocol class literal (without specialization) - used for protocol-to-protocol comparisons
/// to detect recursive generic protocols.
ProtocolClass(ClassLiteral<'db>),
}

impl<'db> From<Type<'db>> for TypeRelationKey<'db> {
fn from(ty: Type<'db>) -> Self {
TypeRelationKey::Type(ty)
}
}

/// A [`CycleDetector`] that is used in `has_relation_to` methods.
pub(crate) type HasRelationToVisitor<'db> = CycleDetector<
TypeRelation<'db>,
(
TypeRelationKey<'db>,
TypeRelationKey<'db>,
TypeRelation<'db>,
),
ConstraintSet<'db>,
>;

impl Default for HasRelationToVisitor<'_> {
fn default() -> Self {
Expand Down Expand Up @@ -1973,7 +2012,7 @@ impl<'db> Type<'db> {
}

(Type::TypeAlias(self_alias), _) => {
relation_visitor.visit((self, target, relation), || {
relation_visitor.visit((self.into(), target.into(), relation), || {
self_alias.value_type(db).has_relation_to_impl(
db,
target,
Expand All @@ -1986,7 +2025,7 @@ impl<'db> Type<'db> {
}

(_, Type::TypeAlias(target_alias)) => {
relation_visitor.visit((self, target, relation), || {
relation_visitor.visit((self.into(), target.into(), relation), || {
self.has_relation_to_impl(
db,
target_alias.value_type(db),
Expand Down Expand Up @@ -2452,7 +2491,7 @@ impl<'db> Type<'db> {
) => ConstraintSet::from(false),

(Type::Callable(self_callable), Type::Callable(other_callable)) => relation_visitor
.visit((self, target, relation), || {
.visit((self.into(), target.into(), relation), || {
self_callable.has_relation_to_impl(
db,
other_callable,
Expand All @@ -2464,7 +2503,7 @@ impl<'db> Type<'db> {
}),

(_, Type::Callable(other_callable)) => {
relation_visitor.visit((self, target, relation), || {
relation_visitor.visit((self.into(), target.into(), relation), || {
self.try_upcast_to_callable(db).when_some_and(|callables| {
callables.has_relation_to_impl(
db,
Expand Down Expand Up @@ -2499,7 +2538,26 @@ impl<'db> Type<'db> {
}

(_, Type::ProtocolInstance(protocol)) => {
relation_visitor.visit((self, target, relation), || {
// For protocol-to-protocol comparisons, use ClassLiteral keys to detect
// infinite recursion with recursive generic protocols (e.g., `class C[T](Protocol): a: C[set[T]]`).
// When both types are protocols of the same class, the types may differ due to
// different specializations, but comparing them would lead to infinite recursion.
let (self_key, target_key) = if let Type::ProtocolInstance(self_protocol) = self {
// Both are protocol instances - try to use class literals as keys
// for detecting cycles in recursive generic protocols
match (self_protocol.class_literal(db), protocol.class_literal(db)) {
(Some(self_class), Some(target_class)) => (
TypeRelationKey::ProtocolClass(self_class),
TypeRelationKey::ProtocolClass(target_class),
),
// One or both are synthesized protocols - fall back to full types
_ => (self.into(), target.into()),
}
} else {
// Source is not a protocol - use full types
(self.into(), target.into())
};
relation_visitor.visit((self_key, target_key, relation), || {
self.satisfies_protocol(
db,
protocol,
Expand All @@ -2515,7 +2573,7 @@ impl<'db> Type<'db> {
(Type::ProtocolInstance(_), _) => ConstraintSet::from(false),

(Type::TypedDict(self_typeddict), Type::TypedDict(other_typeddict)) => relation_visitor
.visit((self, target, relation), || {
.visit((self.into(), target.into(), relation), || {
self_typeddict.has_relation_to_impl(
db,
other_typeddict,
Expand All @@ -2530,18 +2588,23 @@ impl<'db> Type<'db> {
// compatible `Mapping`s. `extra_items` could also allow for some assignments to `dict`, as
// long as `total=False`. (But then again, does anyone want a non-total `TypedDict` where all
// key types are a supertype of the extra items type?)
(Type::TypedDict(_), _) => relation_visitor.visit((self, target, relation), || {
KnownClass::Mapping
.to_specialized_instance(db, [KnownClass::Str.to_instance(db), Type::object()])
.has_relation_to_impl(
db,
target,
inferable,
relation,
relation_visitor,
disjointness_visitor,
)
}),
(Type::TypedDict(_), _) => {
relation_visitor.visit((self.into(), target.into(), relation), || {
KnownClass::Mapping
.to_specialized_instance(
db,
[KnownClass::Str.to_instance(db), Type::object()],
)
.has_relation_to_impl(
db,
target,
inferable,
relation,
relation_visitor,
disjointness_visitor,
)
})
}

// A non-`TypedDict` cannot subtype a `TypedDict`
(_, Type::TypedDict(_)) => ConstraintSet::from(false),
Expand Down Expand Up @@ -2841,7 +2904,7 @@ impl<'db> Type<'db> {
// `bool` is a subtype of `int`, because `bool` subclasses `int`,
// which means that all instances of `bool` are also instances of `int`
(Type::NominalInstance(self_instance), Type::NominalInstance(target_instance)) => {
relation_visitor.visit((self, target, relation), || {
relation_visitor.visit((self.into(), target.into(), relation), || {
self_instance.has_relation_to_impl(
db,
target_instance,
Expand Down
10 changes: 10 additions & 0 deletions crates/ty_python_semantic/src/types/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,16 @@ impl<'db> ProtocolInstanceType<'db> {
}
}

/// If this is a class-based protocol, return its class literal (without specialization).
///
/// Returns `None` for synthesized protocols that don't correspond to a class definition.
pub(super) fn class_literal(self, db: &'db dyn Db) -> Option<ClassLiteral<'db>> {
match self.inner {
Protocol::FromClass(class) => Some(class.class_literal(db).0),
Protocol::Synthesized(_) => None,
}
}

/// Return the meta-type of this protocol-instance type.
pub(super) fn to_meta_type(self, db: &'db dyn Db) -> Type<'db> {
match self.inner {
Expand Down
Loading