Skip to content

Commit db7ced5

Browse files
committed
[ty] Implement equivalence for protocols with method members
1 parent 333191b commit db7ced5

File tree

4 files changed

+91
-17
lines changed

4 files changed

+91
-17
lines changed

crates/ty_python_semantic/resources/mdtest/protocols.md

Lines changed: 82 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,8 +1476,7 @@ class P1(Protocol):
14761476
class P2(Protocol):
14771477
def x(self, y: int) -> None: ...
14781478

1479-
# TODO: this should pass
1480-
static_assert(is_equivalent_to(P1, P2)) # error: [static-assert-error]
1479+
static_assert(is_equivalent_to(P1, P2))
14811480
```
14821481

14831482
As with protocols that only have non-method members, this also holds true when they appear in
@@ -1487,8 +1486,7 @@ differently ordered unions:
14871486
class A: ...
14881487
class B: ...
14891488

1490-
# TODO: this should pass
1491-
static_assert(is_equivalent_to(A | B | P1, P2 | B | A)) # error: [static-assert-error]
1489+
static_assert(is_equivalent_to(A | B | P1, P2 | B | A))
14921490
```
14931491

14941492
## Narrowing of protocols
@@ -1881,6 +1879,86 @@ if isinstance(obj, (B, A)):
18811879
reveal_type(obj) # revealed: (Unknown & B) | (Unknown & A)
18821880
```
18831881

1882+
### Protocols that use `Self`
1883+
1884+
`Self` is a `TypeVar` with an upper bound of the class in which it is defined. This means that
1885+
`Self` annotations in protocols can also be tricky to handle without infinite recursion and stack
1886+
overflows.
1887+
1888+
```toml
1889+
[environment]
1890+
python-version = "3.12"
1891+
```
1892+
1893+
```py
1894+
from typing_extensions import Protocol, Self
1895+
from ty_extensions import static_assert
1896+
1897+
class _HashObject(Protocol):
1898+
def copy(self) -> Self: ...
1899+
1900+
class Foo: ...
1901+
1902+
# Attempting to build this union caused us to overflow on an early version of
1903+
# <https://github.com/astral-sh/ruff/pull/18659>
1904+
x: Foo | _HashObject
1905+
```
1906+
1907+
Some other similar cases that caused issues in our early `Protocol` implementation:
1908+
1909+
`a.py`:
1910+
1911+
```py
1912+
from typing_extensions import Protocol, Self
1913+
1914+
class PGconn(Protocol):
1915+
def connect(self) -> Self: ...
1916+
1917+
class Connection:
1918+
pgconn: PGconn
1919+
1920+
def is_crdb(conn: PGconn) -> bool:
1921+
return isinstance(conn, Connection)
1922+
```
1923+
1924+
and:
1925+
1926+
`b.py`:
1927+
1928+
```py
1929+
from typing_extensions import Protocol
1930+
1931+
class PGconn(Protocol):
1932+
def connect[T: PGconn](self: T) -> T: ...
1933+
1934+
class Connection:
1935+
pgconn: PGconn
1936+
1937+
def f(x: PGconn):
1938+
isinstance(x, Connection)
1939+
```
1940+
1941+
### Recursive protocols used as the first argument to `cast()`
1942+
1943+
These caused issues in an early version of our `Protocol` implementation due to the fact that we use
1944+
a recursive function in our `cast()` implementation to check whether a type contains `Unknown` or
1945+
`Todo`. Recklessly recursing into a type causes stack overflows if the type is recursive:
1946+
1947+
```toml
1948+
[environment]
1949+
python-version = "3.12"
1950+
```
1951+
1952+
```py
1953+
from typing import cast, Protocol
1954+
1955+
class Iterator[T](Protocol):
1956+
def __iter__(self) -> Iterator[T]: ...
1957+
1958+
def f(value: Iterator):
1959+
cast(Iterator, value) # error: [redundant-cast]
1960+
```
1961+
18841962
## TODO
18851963

18861964
Add tests for:

crates/ty_python_semantic/src/types.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1350,8 +1350,7 @@ impl<'db> Type<'db> {
13501350
.is_some_and(|instance| instance.has_relation_to(db, target, relation)),
13511351

13521352
(Type::FunctionLiteral(self_function_literal), Type::Callable(_)) => {
1353-
self_function_literal
1354-
.into_callable_type(db)
1353+
Type::Callable(self_function_literal.into_callable_type(db))
13551354
.has_relation_to(db, target, relation)
13561355
}
13571356

crates/ty_python_semantic/src/types/function.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -767,8 +767,8 @@ impl<'db> FunctionType<'db> {
767767
}
768768

769769
/// Convert the `FunctionType` into a [`Type::Callable`].
770-
pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> {
771-
Type::Callable(CallableType::new(db, self.signature(db), false))
770+
pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> CallableType<'db> {
771+
CallableType::new(db, self.signature(db), false)
772772
}
773773

774774
/// Convert the `FunctionType` into a [`Type::BoundMethod`].

crates/ty_python_semantic/src/types/protocol_class.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ impl<'db> ProtocolMemberData<'db> {
259259

260260
#[derive(Debug, Copy, Clone, PartialEq, Eq, salsa::Update, Hash)]
261261
enum ProtocolMemberKind<'db> {
262-
Method(Type<'db>), // TODO: use CallableType
262+
Method(CallableType<'db>),
263263
Property(PropertyInstanceType<'db>),
264264
Other(Type<'db>),
265265
}
@@ -334,7 +334,7 @@ fn walk_protocol_member<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>(
334334
visitor: &mut V,
335335
) {
336336
match member.kind {
337-
ProtocolMemberKind::Method(method) => visitor.visit_type(db, method),
337+
ProtocolMemberKind::Method(method) => visitor.visit_callable_type(db, method),
338338
ProtocolMemberKind::Property(property) => {
339339
visitor.visit_property_instance_type(db, property);
340340
}
@@ -353,7 +353,7 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {
353353

354354
fn ty(&self) -> Type<'db> {
355355
match &self.kind {
356-
ProtocolMemberKind::Method(callable) => *callable,
356+
ProtocolMemberKind::Method(callable) => Type::Callable(*callable),
357357
ProtocolMemberKind::Property(property) => Type::PropertyInstance(*property),
358358
ProtocolMemberKind::Other(ty) => *ty,
359359
}
@@ -500,13 +500,10 @@ fn cached_protocol_interface<'db>(
500500
(Type::Callable(callable), BoundOnClass::Yes)
501501
if callable.is_function_like(db) =>
502502
{
503-
ProtocolMemberKind::Method(ty)
503+
ProtocolMemberKind::Method(callable)
504504
}
505-
// TODO: method members that have `FunctionLiteral` types should be upcast
506-
// to `CallableType` so that two protocols with identical method members
507-
// are recognized as equivalent.
508-
(Type::FunctionLiteral(_function), BoundOnClass::Yes) => {
509-
ProtocolMemberKind::Method(ty)
505+
(Type::FunctionLiteral(function), BoundOnClass::Yes) => {
506+
ProtocolMemberKind::Method(function.into_callable_type(db))
510507
}
511508
_ => ProtocolMemberKind::Other(ty),
512509
};

0 commit comments

Comments
 (0)