diff --git a/crates/ty_python_semantic/resources/corpus/cyclic_protocol.py b/crates/ty_python_semantic/resources/corpus/cyclic_protocol.py new file mode 100644 index 0000000000000..2843b4c560932 --- /dev/null +++ b/crates/ty_python_semantic/resources/corpus/cyclic_protocol.py @@ -0,0 +1,24 @@ +# Regression test for https://github.com/astral-sh/ty/issues/3080 + +# To reproduce the bug, deferred evaluation of type annotations must be applied. +from __future__ import annotations + +from typing import Generic, Protocol, Self, TypeVar, overload + +S = TypeVar("S") +T = TypeVar("T") + + +class Unit(Protocol): + def __mul__(self, other: S | Quantity[S]): ... + + +class Vector(Protocol): ... + + +class Quantity(Generic[T], Protocol): + @overload + def __mul__(self, other: Unit | Quantity[S]): ... + + @overload + def __mul__(self, other: Vector) -> Vector: ... diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 47358d0693fbc..ab2bebda10c4e 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -1136,7 +1136,12 @@ impl<'db> FunctionType<'db> { /// /// Were this not a salsa query, then the calling query /// would depend on the function's AST and rerun for every change in that file. - #[salsa::tracked(returns(ref), cycle_initial=|_, _, _| CallableSignature::single(Signature::bottom()), heap_size=ruff_memory_usage::heap_size)] + #[salsa::tracked( + returns(ref), + cycle_initial=|_, _, _| CallableSignature::single(Signature::bottom()), + cycle_fn=|db, cycle, previous, value: CallableSignature<'db>, _| value.cycle_normalized(db, previous, cycle), + heap_size=ruff_memory_usage::heap_size, + )] pub(crate) fn signature(self, db: &'db dyn Db) -> CallableSignature<'db> { self.updated_signature(db) .cloned() diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index 2438abd1f92ec..d749b2e55b1c6 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -256,6 +256,24 @@ impl<'db> ProtocolInterface<'db> { Self::new(db, BTreeMap::default()) } + fn cycle_normalized(self, db: &'db dyn Db, previous: Self, cycle: &salsa::Cycle) -> Self { + let prev_inner = previous.inner(db); + let curr_inner = self.inner(db); + + let members: BTreeMap<_, _> = curr_inner + .iter() + .map(|(name, curr_data)| { + let normalized = if let Some(prev_data) = prev_inner.get(name) { + curr_data.cycle_normalized(db, prev_data, cycle) + } else { + curr_data.clone() + }; + (name.clone(), normalized) + }) + .collect(); + Self::new(db, members) + } + pub(super) fn members<'a>( self, db: &'db dyn Db, @@ -404,6 +422,14 @@ pub(super) struct ProtocolMemberData<'db> { } impl<'db> ProtocolMemberData<'db> { + fn cycle_normalized(&self, db: &'db dyn Db, previous: &Self, cycle: &salsa::Cycle) -> Self { + Self { + kind: self.kind.cycle_normalized(db, &previous.kind, cycle), + qualifiers: self.qualifiers, + definition: self.definition, + } + } + fn recursive_type_normalized_impl( &self, db: &'db dyn Db, @@ -509,6 +535,38 @@ enum ProtocolMemberKind<'db> { } impl<'db> ProtocolMemberKind<'db> { + fn cycle_normalized(&self, db: &'db dyn Db, previous: &Self, cycle: &salsa::Cycle) -> Self { + match (self, previous) { + (Self::Method(curr), Self::Method(prev)) => { + debug_assert_eq!(curr.kind(db), prev.kind(db)); + let normalized = + curr.signatures(db) + .cycle_normalized(db, prev.signatures(db), cycle); + Self::Method(CallableType::new(db, normalized, curr.kind(db))) + } + (Self::Property(curr), Self::Property(prev)) => { + let getter = match (curr.getter(db), prev.getter(db)) { + (Some(curr), Some(prev)) => Some(curr.cycle_normalized(db, prev, cycle)), + (Some(curr), None) => Some(curr.recursive_type_normalized(db, cycle)), + (None, _) => None, + }; + let setter = match (curr.setter(db), prev.setter(db)) { + (Some(curr), Some(prev)) => Some(curr.cycle_normalized(db, prev, cycle)), + (Some(curr), None) => Some(curr.recursive_type_normalized(db, cycle)), + (None, _) => None, + }; + Self::Property(PropertyInstanceType::new(db, getter, setter)) + } + (Self::Other(curr), Self::Other(prev)) => { + Self::Other(curr.cycle_normalized(db, *prev, cycle)) + } + _ => { + debug_assert!(matches!(previous, Self::Other(ty) if ty.is_divergent())); + *self + } + } + } + fn apply_type_mapping_impl<'a>( &self, db: &'db dyn Db, @@ -850,7 +908,11 @@ impl BoundOnClass { } /// Inner Salsa query for [`ProtocolClass::interface`]. -#[salsa::tracked(cycle_initial=proto_interface_cycle_initial, heap_size=ruff_memory_usage::heap_size)] +#[salsa::tracked( + cycle_initial=proto_interface_cycle_initial, + cycle_fn=proto_interface_cycle_recover, + heap_size=ruff_memory_usage::heap_size, +)] fn cached_protocol_interface<'db>( db: &'db dyn Db, class: ClassType<'db>, @@ -971,6 +1033,17 @@ fn proto_interface_cycle_initial<'db>( ProtocolInterface::empty(db) } +#[allow(clippy::trivially_copy_pass_by_ref)] +fn proto_interface_cycle_recover<'db>( + db: &'db dyn Db, + cycle: &salsa::Cycle, + previous: &ProtocolInterface<'db>, + value: ProtocolInterface<'db>, + _class: ClassType<'db>, +) -> ProtocolInterface<'db> { + value.cycle_normalized(db, *previous, cycle) +} + /// Bind `self`, and *also* discard the functionlike-ness of the callable. /// /// This additional upcasting is required in order for protocols with `__call__` method diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 4b767029ccc67..0ae589665c09b 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -109,6 +109,27 @@ impl<'db> CallableSignature<'db> { })) } + pub(crate) fn cycle_normalized( + &self, + db: &'db dyn Db, + previous: &Self, + cycle: &salsa::Cycle, + ) -> Self { + if previous.overloads.len() == self.overloads.len() { + Self { + overloads: self + .overloads + .iter() + .zip(previous.overloads.iter()) + .map(|(curr, prev)| curr.cycle_normalized(db, prev, cycle)) + .collect(), + } + } else { + debug_assert_eq!(previous, &Self::bottom()); + self.clone() + } + } + pub(super) fn recursive_type_normalized_impl( &self, db: &'db dyn Db, @@ -528,6 +549,32 @@ impl<'db> Signature<'db> { self } + fn cycle_normalized(&self, db: &'db dyn Db, previous: &Self, cycle: &salsa::Cycle) -> Self { + let return_ty = self + .return_ty + .cycle_normalized(db, previous.return_ty, cycle); + + let parameters = if self.parameters.len() == previous.parameters.len() { + Parameters::new( + db, + self.parameters + .iter() + .zip(previous.parameters.iter()) + .map(|(curr, prev)| curr.cycle_normalized(db, prev, cycle)), + ) + } else { + debug_assert_eq!(previous.parameters, Parameters::bottom()); + self.parameters.clone() + }; + + Self { + generic_context: self.generic_context, + definition: self.definition, + parameters, + return_ty, + } + } + pub(super) fn recursive_type_normalized_impl( &self, db: &'db dyn Db, @@ -2254,6 +2301,22 @@ impl<'db> Parameter<'db> { } } + fn cycle_normalized(&self, db: &'db dyn Db, previous: &Self, cycle: &salsa::Cycle) -> Self { + let annotated_type = + self.annotated_type + .cycle_normalized(db, previous.annotated_type, cycle); + + let kind = self.kind.cycle_normalized(db, &previous.kind, cycle); + + Self { + annotated_type, + inferred_annotation: self.inferred_annotation, + has_starred_annotation: self.has_starred_annotation, + kind, + form: self.form, + } + } + pub(super) fn recursive_type_normalized_impl( &self, db: &'db dyn Db, @@ -2496,6 +2559,59 @@ pub enum ParameterKind<'db> { } impl<'db> ParameterKind<'db> { + #[expect(clippy::ref_option)] + fn cycle_normalized_default( + db: &'db dyn Db, + current: &Option>, + previous: &Option>, + cycle: &salsa::Cycle, + ) -> Option> { + match (current, previous) { + (Some(curr), Some(prev)) => Some(curr.cycle_normalized(db, *prev, cycle)), + (Some(curr), None) => Some(curr.recursive_type_normalized(db, cycle)), + (None, _) => *current, + } + } + + fn cycle_normalized(&self, db: &'db dyn Db, previous: &Self, cycle: &salsa::Cycle) -> Self { + match (self, previous) { + ( + ParameterKind::PositionalOnly { name, default_type }, + ParameterKind::PositionalOnly { + default_type: prev_default, + .. + }, + ) => ParameterKind::PositionalOnly { + name: name.clone(), + default_type: Self::cycle_normalized_default(db, default_type, prev_default, cycle), + }, + ( + ParameterKind::PositionalOrKeyword { name, default_type }, + ParameterKind::PositionalOrKeyword { + default_type: prev_default, + .. + }, + ) => ParameterKind::PositionalOrKeyword { + name: name.clone(), + default_type: Self::cycle_normalized_default(db, default_type, prev_default, cycle), + }, + ( + ParameterKind::KeywordOnly { name, default_type }, + ParameterKind::KeywordOnly { + default_type: prev_default, + .. + }, + ) => ParameterKind::KeywordOnly { + name: name.clone(), + default_type: Self::cycle_normalized_default(db, default_type, prev_default, cycle), + }, + // Variadic / KeywordVariadic have no types to normalize. + // Also, if the current `ParameterKind` is different from `previous`, it means that `previous` is the cycle initial value, + // and the current value should take precedence. + _ => self.clone(), + } + } + fn apply_type_mapping_impl<'a>( &self, db: &'db dyn Db,