Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions crates/ty_python_semantic/resources/corpus/cyclic_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Regression test for https://github.com/astral-sh/ty/issues/3080

# To reproduce the bug, deferred evaluation of type annotations must be applied.
from __future__ import annotations

from typing import Generic, Protocol, Self, TypeVar, overload

S = TypeVar("S")
T = TypeVar("T")


class Unit(Protocol):
def __mul__(self, other: S | Quantity[S]): ...


class Vector(Protocol): ...


class Quantity(Generic[T], Protocol):
@overload
def __mul__(self, other: Unit | Quantity[S]): ...

@overload
def __mul__(self, other: Vector) -> Vector: ...
7 changes: 6 additions & 1 deletion crates/ty_python_semantic/src/types/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1136,7 +1136,12 @@ impl<'db> FunctionType<'db> {
///
/// Were this not a salsa query, then the calling query
/// would depend on the function's AST and rerun for every change in that file.
#[salsa::tracked(returns(ref), cycle_initial=|_, _, _| CallableSignature::single(Signature::bottom()), heap_size=ruff_memory_usage::heap_size)]
#[salsa::tracked(
returns(ref),
cycle_initial=|_, _, _| CallableSignature::single(Signature::bottom()),
cycle_fn=|db, cycle, previous, value: CallableSignature<'db>, _| value.cycle_normalized(db, previous, cycle),
heap_size=ruff_memory_usage::heap_size,
)]
pub(crate) fn signature(self, db: &'db dyn Db) -> CallableSignature<'db> {
self.updated_signature(db)
.cloned()
Expand Down
75 changes: 74 additions & 1 deletion crates/ty_python_semantic/src/types/protocol_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,24 @@ impl<'db> ProtocolInterface<'db> {
Self::new(db, BTreeMap::default())
}

fn cycle_normalized(self, db: &'db dyn Db, previous: Self, cycle: &salsa::Cycle) -> Self {
let prev_inner = previous.inner(db);
let curr_inner = self.inner(db);

let members: BTreeMap<_, _> = curr_inner
.iter()
.map(|(name, curr_data)| {
let normalized = if let Some(prev_data) = prev_inner.get(name) {
curr_data.cycle_normalized(db, prev_data, cycle)
} else {
curr_data.clone()
};
(name.clone(), normalized)
})
.collect();
Self::new(db, members)
}

pub(super) fn members<'a>(
self,
db: &'db dyn Db,
Expand Down Expand Up @@ -404,6 +422,14 @@ pub(super) struct ProtocolMemberData<'db> {
}

impl<'db> ProtocolMemberData<'db> {
fn cycle_normalized(&self, db: &'db dyn Db, previous: &Self, cycle: &salsa::Cycle) -> Self {
Self {
kind: self.kind.cycle_normalized(db, &previous.kind, cycle),
qualifiers: self.qualifiers,
definition: self.definition,
}
}

fn recursive_type_normalized_impl(
&self,
db: &'db dyn Db,
Expand Down Expand Up @@ -509,6 +535,38 @@ enum ProtocolMemberKind<'db> {
}

impl<'db> ProtocolMemberKind<'db> {
fn cycle_normalized(&self, db: &'db dyn Db, previous: &Self, cycle: &salsa::Cycle) -> Self {
match (self, previous) {
(Self::Method(curr), Self::Method(prev)) => {
debug_assert_eq!(curr.kind(db), prev.kind(db));
let normalized =
curr.signatures(db)
.cycle_normalized(db, prev.signatures(db), cycle);
Self::Method(CallableType::new(db, normalized, curr.kind(db)))
}
(Self::Property(curr), Self::Property(prev)) => {
let getter = match (curr.getter(db), prev.getter(db)) {
(Some(curr), Some(prev)) => Some(curr.cycle_normalized(db, prev, cycle)),
(Some(curr), None) => Some(curr.recursive_type_normalized(db, cycle)),
(None, _) => None,
};
let setter = match (curr.setter(db), prev.setter(db)) {
(Some(curr), Some(prev)) => Some(curr.cycle_normalized(db, prev, cycle)),
(Some(curr), None) => Some(curr.recursive_type_normalized(db, cycle)),
(None, _) => None,
};
Self::Property(PropertyInstanceType::new(db, getter, setter))
}
(Self::Other(curr), Self::Other(prev)) => {
Self::Other(curr.cycle_normalized(db, *prev, cycle))
}
_ => {
debug_assert!(matches!(previous, Self::Other(ty) if ty.is_divergent()));
*self
}
}
}

fn apply_type_mapping_impl<'a>(
&self,
db: &'db dyn Db,
Expand Down Expand Up @@ -850,7 +908,11 @@ impl BoundOnClass {
}

/// Inner Salsa query for [`ProtocolClass::interface`].
#[salsa::tracked(cycle_initial=proto_interface_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
#[salsa::tracked(
cycle_initial=proto_interface_cycle_initial,
cycle_fn=proto_interface_cycle_recover,
heap_size=ruff_memory_usage::heap_size,
)]
fn cached_protocol_interface<'db>(
db: &'db dyn Db,
class: ClassType<'db>,
Expand Down Expand Up @@ -971,6 +1033,17 @@ fn proto_interface_cycle_initial<'db>(
ProtocolInterface::empty(db)
}

#[allow(clippy::trivially_copy_pass_by_ref)]
fn proto_interface_cycle_recover<'db>(
db: &'db dyn Db,
cycle: &salsa::Cycle,
previous: &ProtocolInterface<'db>,
value: ProtocolInterface<'db>,
_class: ClassType<'db>,
) -> ProtocolInterface<'db> {
value.cycle_normalized(db, *previous, cycle)
}

/// Bind `self`, and *also* discard the functionlike-ness of the callable.
///
/// This additional upcasting is required in order for protocols with `__call__` method
Expand Down
116 changes: 116 additions & 0 deletions crates/ty_python_semantic/src/types/signatures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,27 @@ impl<'db> CallableSignature<'db> {
}))
}

pub(crate) fn cycle_normalized(
&self,
db: &'db dyn Db,
previous: &Self,
cycle: &salsa::Cycle,
) -> Self {
if previous.overloads.len() == self.overloads.len() {
Self {
overloads: self
.overloads
.iter()
.zip(previous.overloads.iter())
.map(|(curr, prev)| curr.cycle_normalized(db, prev, cycle))
.collect(),
}
} else {
debug_assert_eq!(previous, &Self::bottom());
self.clone()
}
}

pub(super) fn recursive_type_normalized_impl(
&self,
db: &'db dyn Db,
Expand Down Expand Up @@ -528,6 +549,32 @@ impl<'db> Signature<'db> {
self
}

fn cycle_normalized(&self, db: &'db dyn Db, previous: &Self, cycle: &salsa::Cycle) -> Self {
let return_ty = self
.return_ty
.cycle_normalized(db, previous.return_ty, cycle);

let parameters = if self.parameters.len() == previous.parameters.len() {
Parameters::new(
db,
self.parameters
.iter()
.zip(previous.parameters.iter())
.map(|(curr, prev)| curr.cycle_normalized(db, prev, cycle)),
)
} else {
debug_assert_eq!(previous.parameters, Parameters::bottom());
self.parameters.clone()
};

Self {
generic_context: self.generic_context,
definition: self.definition,
parameters,
return_ty,
}
}

pub(super) fn recursive_type_normalized_impl(
&self,
db: &'db dyn Db,
Expand Down Expand Up @@ -2254,6 +2301,22 @@ impl<'db> Parameter<'db> {
}
}

fn cycle_normalized(&self, db: &'db dyn Db, previous: &Self, cycle: &salsa::Cycle) -> Self {
let annotated_type =
self.annotated_type
.cycle_normalized(db, previous.annotated_type, cycle);

let kind = self.kind.cycle_normalized(db, &previous.kind, cycle);

Self {
annotated_type,
inferred_annotation: self.inferred_annotation,
has_starred_annotation: self.has_starred_annotation,
kind,
form: self.form,
}
}

pub(super) fn recursive_type_normalized_impl(
&self,
db: &'db dyn Db,
Expand Down Expand Up @@ -2496,6 +2559,59 @@ pub enum ParameterKind<'db> {
}

impl<'db> ParameterKind<'db> {
#[expect(clippy::ref_option)]
fn cycle_normalized_default(
db: &'db dyn Db,
current: &Option<Type<'db>>,
previous: &Option<Type<'db>>,
cycle: &salsa::Cycle,
) -> Option<Type<'db>> {
match (current, previous) {
(Some(curr), Some(prev)) => Some(curr.cycle_normalized(db, *prev, cycle)),
(Some(curr), None) => Some(curr.recursive_type_normalized(db, cycle)),
(None, _) => *current,
}
}

fn cycle_normalized(&self, db: &'db dyn Db, previous: &Self, cycle: &salsa::Cycle) -> Self {
match (self, previous) {
(
ParameterKind::PositionalOnly { name, default_type },
ParameterKind::PositionalOnly {
default_type: prev_default,
..
},
) => ParameterKind::PositionalOnly {
name: name.clone(),
default_type: Self::cycle_normalized_default(db, default_type, prev_default, cycle),
},
(
ParameterKind::PositionalOrKeyword { name, default_type },
ParameterKind::PositionalOrKeyword {
default_type: prev_default,
..
},
) => ParameterKind::PositionalOrKeyword {
name: name.clone(),
default_type: Self::cycle_normalized_default(db, default_type, prev_default, cycle),
},
(
ParameterKind::KeywordOnly { name, default_type },
ParameterKind::KeywordOnly {
default_type: prev_default,
..
},
) => ParameterKind::KeywordOnly {
name: name.clone(),
default_type: Self::cycle_normalized_default(db, default_type, prev_default, cycle),
},
// Variadic / KeywordVariadic have no types to normalize.
// Also, if the current `ParameterKind` is different from `previous`, it means that `previous` is the cycle initial value,
// and the current value should take precedence.
_ => self.clone(),
}
}

fn apply_type_mapping_impl<'a>(
&self,
db: &'db dyn Db,
Expand Down
Loading