diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md index e101384ce6a4f1..380452f6c26f82 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md @@ -454,6 +454,28 @@ reveal_type(d.method3()) # revealed: SomeProtocol[int] reveal_type(d.method3().x) # revealed: int ``` +When a method is overloaded, the specialization is applied to all overloads. + +```py +from typing import overload, Generic, TypeVar + +S = TypeVar("S") + +class WithOverloadedMethod(Generic[T]): + @overload + def method(self, x: T) -> T: + return x + + @overload + def method(self, x: S) -> S | T: + return x + + def method(self, x: S | T) -> S | T: + return x + +reveal_type(WithOverloadedMethod[int].method) # revealed: Overload[(self, x: int) -> int, (self, x: S) -> S | int] +``` + ## Cyclic class definitions ### F-bounded quantification diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md index b342137e7ffa89..c5d4772d362f32 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md @@ -347,6 +347,26 @@ reveal_type(c.method2()) # revealed: str reveal_type(c.method3()) # revealed: LinkedList[int] ``` +When a method is overloaded, the specialization is applied to all overloads. + +```py +from typing import overload + +class WithOverloadedMethod[T]: + @overload + def method(self, x: T) -> T: + return x + + @overload + def method[S](self, x: S) -> S | T: + return x + + def method[S](self, x: S | T) -> S | T: + return x + +reveal_type(WithOverloadedMethod[int].method) # revealed: Overload[(self, x: int) -> int, (self, x: S) -> S | int] +``` + ## Cyclic class definitions ### F-bounded quantification diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 9dc4b4cc6d3669..a679172c53a8c5 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -3416,16 +3416,10 @@ impl<'db> Type<'db> { Type::BoundMethod(bound_method) => { let signature = bound_method.function(db).signature(db); - Signatures::single(match signature { - FunctionSignature::Single(signature) => { - CallableSignature::single(self, signature.clone()) - .with_bound_type(bound_method.self_instance(db)) - } - FunctionSignature::Overloaded(signatures, _) => { - CallableSignature::from_overloads(self, signatures.iter().cloned()) - .with_bound_type(bound_method.self_instance(db)) - } - }) + Signatures::single( + CallableSignature::from_overloads(self, signature.overloads.iter().cloned()) + .with_bound_type(bound_method.self_instance(db)), + ) } Type::MethodWrapper( @@ -3783,14 +3777,7 @@ impl<'db> Type<'db> { Signatures::single(signature) } - _ => Signatures::single(match function_type.signature(db) { - FunctionSignature::Single(signature) => { - CallableSignature::single(self, signature.clone()) - } - FunctionSignature::Overloaded(signatures, _) => { - CallableSignature::from_overloads(self, signatures.iter().cloned()) - } - }), + _ => Signatures::single(function_type.signature(db).overloads.clone()), }, Type::ClassLiteral(class) => match class.known(db) { @@ -6559,46 +6546,21 @@ bitflags! { } } -/// A function signature, which can be either a single signature or an overloaded signature. +/// A function signature, which optionally includes an implementation signature if the function is +/// overloaded. #[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update)] -pub(crate) enum FunctionSignature<'db> { - /// A single function signature. - Single(Signature<'db>), - - /// An overloaded function signature containing the `@overload`-ed signatures and an optional - /// implementation signature. - Overloaded(Vec>, Option>), +pub(crate) struct FunctionSignature<'db> { + pub(crate) overloads: CallableSignature<'db>, + pub(crate) implementation: Option>, } impl<'db> FunctionSignature<'db> { - /// Returns a slice of all signatures. - /// - /// For an overloaded function, this only includes the `@overload`-ed signatures and not the - /// implementation signature. - pub(crate) fn as_slice(&self) -> &[Signature<'db>] { - match self { - Self::Single(signature) => std::slice::from_ref(signature), - Self::Overloaded(signatures, _) => signatures, - } - } - - /// Returns an iterator over the signatures. - pub(crate) fn iter(&self) -> Iter> { - self.as_slice().iter() - } - /// Returns the "bottom" signature (subtype of all fully-static signatures.) pub(crate) fn bottom(db: &'db dyn Db) -> Self { - Self::Single(Signature::bottom(db)) - } -} - -impl<'db> IntoIterator for &'db FunctionSignature<'db> { - type Item = &'db Signature<'db>; - type IntoIter = Iter<'db, Signature<'db>>; - - fn into_iter(self) -> Self::IntoIter { - self.iter() + FunctionSignature { + overloads: CallableSignature::single(Type::any(), Signature::bottom(db)), + implementation: None, + } } } @@ -6669,7 +6631,7 @@ impl<'db> FunctionType<'db> { pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> { Type::Callable(CallableType::from_overloads( db, - self.signature(db).iter().cloned(), + self.signature(db).overloads.iter().cloned(), )) } @@ -6737,20 +6699,32 @@ impl<'db> FunctionType<'db> { /// would depend on the function's AST and rerun for every change in that file. #[salsa::tracked(returns(ref), cycle_fn=signature_cycle_recover, cycle_initial=signature_cycle_initial)] pub(crate) fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> { + let specialization = self.specialization(db); if let Some(overloaded) = self.to_overloaded(db) { - FunctionSignature::Overloaded( - overloaded - .overloads - .iter() - .copied() - .map(|overload| overload.internal_signature(db)) - .collect(), - overloaded - .implementation - .map(|implementation| implementation.internal_signature(db)), - ) + FunctionSignature { + overloads: CallableSignature::from_overloads( + Type::FunctionLiteral(self), + overloaded.overloads.iter().copied().map(|overload| { + overload + .internal_signature(db) + .apply_optional_specialization(db, specialization) + }), + ), + implementation: overloaded.implementation.map(|implementation| { + implementation + .internal_signature(db) + .apply_optional_specialization(db, specialization) + }), + } } else { - FunctionSignature::Single(self.internal_signature(db)) + FunctionSignature { + overloads: CallableSignature::single( + Type::FunctionLiteral(self), + self.internal_signature(db) + .apply_optional_specialization(db, specialization), + ), + implementation: None, + } } } @@ -6772,17 +6746,13 @@ impl<'db> FunctionType<'db> { let index = semantic_index(db, scope.file(db)); GenericContext::from_type_params(db, index, type_params) }); - let mut signature = Signature::from_function( + Signature::from_function( db, generic_context, self.inherited_generic_context(db), definition, function_stmt_node, - ); - if let Some(specialization) = self.specialization(db) { - signature = signature.apply_specialization(db, specialization); - } - signature + ) } pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool { @@ -6852,7 +6822,7 @@ impl<'db> FunctionType<'db> { typevars: &mut FxOrderSet>, ) { let signatures = self.signature(db); - for signature in signatures { + for signature in &signatures.overloads { signature.find_legacy_typevars(db, typevars); } } @@ -7112,6 +7082,7 @@ impl<'db> BoundMethodType<'db> { db, self.function(db) .signature(db) + .overloads .iter() .map(signatures::Signature::bind_self), )) diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index b5bc35e0e2a2a5..d0716a6459b9f0 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -10,9 +10,9 @@ use crate::types::class::{ClassLiteral, ClassType, GenericAlias}; use crate::types::generics::{GenericContext, Specialization}; use crate::types::signatures::{Parameter, Parameters, Signature}; use crate::types::{ - CallableType, FunctionSignature, IntersectionType, KnownClass, MethodWrapperKind, Protocol, - StringLiteralType, SubclassOfInner, Type, TypeVarBoundOrConstraints, TypeVarInstance, - UnionType, WrapperDescriptorKind, + CallableType, IntersectionType, KnownClass, MethodWrapperKind, Protocol, StringLiteralType, + SubclassOfInner, Type, TypeVarBoundOrConstraints, TypeVarInstance, UnionType, + WrapperDescriptorKind, }; use crate::{Db, FxOrderSet}; @@ -118,8 +118,8 @@ impl Display for DisplayRepresentation<'_> { // the generic type parameters to the signature, i.e. // show `def foo[T](x: T) -> T`. - match signature { - FunctionSignature::Single(signature) => { + match signature.overloads.as_slice() { + [signature] => { write!( f, // "def {name}{specialization}{signature}", @@ -128,7 +128,7 @@ impl Display for DisplayRepresentation<'_> { signature = signature.display(self.db) ) } - FunctionSignature::Overloaded(signatures, _) => { + signatures => { // TODO: How to display overloads? f.write_str("Overload[")?; let mut join = f.join(", "); @@ -146,8 +146,8 @@ impl Display for DisplayRepresentation<'_> { // TODO: use the specialization from the method. Similar to the comment above // about the function specialization, - match function.signature(self.db) { - FunctionSignature::Single(signature) => { + match function.signature(self.db).overloads.as_slice() { + [signature] => { write!( f, "bound method {instance}.{method}{signature}", @@ -156,7 +156,7 @@ impl Display for DisplayRepresentation<'_> { signature = signature.bind_self().display(self.db) ) } - FunctionSignature::Overloaded(signatures, _) => { + signatures => { // TODO: How to display overloads? f.write_str("Overload[")?; let mut join = f.join(", "); diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 7cbead0b751c73..25169d3f570513 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -195,6 +195,10 @@ impl<'db> CallableSignature<'db> { self.overloads.iter() } + pub(crate) fn as_slice(&self) -> &[Signature<'db>] { + self.overloads.as_slice() + } + fn replace_callable_type(&mut self, before: Type<'db>, after: Type<'db>) { if self.callable_type == before { self.callable_type = after; @@ -309,12 +313,16 @@ impl<'db> Signature<'db> { } } - pub(crate) fn apply_specialization( - &self, + pub(crate) fn apply_optional_specialization( + self, db: &'db dyn Db, - specialization: Specialization<'db>, + specialization: Option>, ) -> Self { - self.apply_type_mapping(db, specialization.type_mapping()) + if let Some(specialization) = specialization { + self.apply_type_mapping(db, specialization.type_mapping()) + } else { + self + } } pub(crate) fn apply_type_mapping<'a>( @@ -1743,7 +1751,10 @@ mod tests { // With no decorators, internal and external signature are the same assert_eq!( func.signature(&db), - &FunctionSignature::Single(expected_sig) + &FunctionSignature { + overloads: CallableSignature::single(Type::FunctionLiteral(func), expected_sig), + implementation: None + }, ); } }