diff --git a/crates/ty_python_semantic/resources/mdtest/dataclasses.md b/crates/ty_python_semantic/resources/mdtest/dataclasses.md index fcbe3f6c1b1206..a74c125b5dbe65 100644 --- a/crates/ty_python_semantic/resources/mdtest/dataclasses.md +++ b/crates/ty_python_semantic/resources/mdtest/dataclasses.md @@ -56,8 +56,6 @@ Person(20, "Eve") ## Signature of `__init__` -TODO: All of the following tests are missing the `self` argument in the `__init__` signature. - Declarations in the class body are used to generate the signature of the `__init__` method. If the attributes are not just declarations, but also bindings, the type inferred from bindings is used as the default value. @@ -71,7 +69,7 @@ class D: y: str = "default" z: int | None = 1 + 2 -reveal_type(D.__init__) # revealed: (x: int, y: str = Literal["default"], z: int | None = Literal[3]) -> None +reveal_type(D.__init__) # revealed: (self: D, x: int, y: str = Literal["default"], z: int | None = Literal[3]) -> None ``` This also works if the declaration and binding are split: @@ -82,7 +80,7 @@ class D: x: int | None x = None -reveal_type(D.__init__) # revealed: (x: int | None = None) -> None +reveal_type(D.__init__) # revealed: (self: D, x: int | None = None) -> None ``` Non-fully static types are handled correctly: @@ -96,7 +94,7 @@ class C: y: int | Any z: tuple[int, Any] -reveal_type(C.__init__) # revealed: (x: Any, y: int | Any, z: tuple[int, Any]) -> None +reveal_type(C.__init__) # revealed: (self: C, x: Any, y: int | Any, z: tuple[int, Any]) -> None ``` Variables without annotations are ignored: @@ -107,7 +105,7 @@ class D: x: int y = 1 -reveal_type(D.__init__) # revealed: (x: int) -> None +reveal_type(D.__init__) # revealed: (self: D, x: int) -> None ``` If attributes without default values are declared after attributes with default values, a @@ -132,7 +130,7 @@ class D: y: ClassVar[str] = "default" z: bool -reveal_type(D.__init__) # revealed: (x: int, z: bool) -> None +reveal_type(D.__init__) # revealed: (self: D, x: int, z: bool) -> None d = D(1, True) reveal_type(d.x) # revealed: int @@ -150,7 +148,7 @@ class D: def y(self) -> str: return "" -reveal_type(D.__init__) # revealed: (x: int) -> None +reveal_type(D.__init__) # revealed: (self: D, x: int) -> None ``` And neither do nested class declarations: @@ -163,7 +161,7 @@ class D: class Nested: y: str -reveal_type(D.__init__) # revealed: (x: int) -> None +reveal_type(D.__init__) # revealed: (self: D, x: int) -> None ``` But if there is a variable annotation with a function or class literal type, the signature of @@ -181,7 +179,7 @@ class D: class_literal: TypeOf[SomeClass] class_subtype_of: type[SomeClass] -# revealed: (function_literal: def some_function() -> None, class_literal: , class_subtype_of: type[SomeClass]) -> None +# revealed: (self: D, function_literal: def some_function() -> None, class_literal: , class_subtype_of: type[SomeClass]) -> None reveal_type(D.__init__) ``` @@ -194,7 +192,7 @@ from typing import Callable class D: c: Callable[[int], str] -reveal_type(D.__init__) # revealed: (c: (int, /) -> str) -> None +reveal_type(D.__init__) # revealed: (self: D, c: (int, /) -> str) -> None ``` Implicit instance attributes do not affect the signature of `__init__`: @@ -209,7 +207,7 @@ class D: reveal_type(D(1).y) # revealed: str -reveal_type(D.__init__) # revealed: (x: int) -> None +reveal_type(D.__init__) # revealed: (self: D, x: int) -> None ``` Annotating expressions does not lead to an entry in `__annotations__` at runtime, and so it wouldn't @@ -222,7 +220,7 @@ class D: (x): int = 1 # TODO: should ideally not include a `x` parameter -reveal_type(D.__init__) # revealed: (x: int = Literal[1]) -> None +reveal_type(D.__init__) # revealed: (self: D, x: int = Literal[1]) -> None ``` ## `@dataclass` calls with arguments @@ -529,7 +527,7 @@ class C(Base): z: int = 10 x: int = 15 -reveal_type(C.__init__) # revealed: (x: int = Literal[15], y: int = Literal[0], z: int = Literal[10]) -> None +reveal_type(C.__init__) # revealed: (self: C, x: int = Literal[15], y: int = Literal[0], z: int = Literal[10]) -> None ``` ## Generic dataclasses @@ -582,7 +580,7 @@ class UppercaseString: class C: upper: UppercaseString = UppercaseString() -reveal_type(C.__init__) # revealed: (upper: str = str) -> None +reveal_type(C.__init__) # revealed: (self: C, upper: str = str) -> None c = C("abc") reveal_type(c.upper) # revealed: str @@ -628,7 +626,7 @@ class ConvertToLength: class C: converter: ConvertToLength = ConvertToLength() -reveal_type(C.__init__) # revealed: (converter: str = Literal[""]) -> None +reveal_type(C.__init__) # revealed: (self: C, converter: str = Literal[""]) -> None c = C("abc") reveal_type(c.converter) # revealed: int @@ -667,7 +665,7 @@ class AcceptsStrAndInt: class C: field: AcceptsStrAndInt = AcceptsStrAndInt() -reveal_type(C.__init__) # revealed: (field: str | int = int) -> None +reveal_type(C.__init__) # revealed: (self: C, field: str | int = int) -> None ``` ## `dataclasses.field` @@ -728,7 +726,7 @@ import dataclasses class C: x: str -reveal_type(C.__init__) # revealed: (x: str) -> None +reveal_type(C.__init__) # revealed: (self: C, x: str) -> None ``` ### Dataclass with custom `__init__` method @@ -821,10 +819,57 @@ reveal_type(Person.__mro__) # revealed: tuple[, None +reveal_type(Person.__init__) # revealed: (self: Person, name: str, age: int | None = None) -> None reveal_type(Person.__repr__) # revealed: def __repr__(self) -> str reveal_type(Person.__eq__) # revealed: def __eq__(self, value: object, /) -> bool ``` + +## Function-like behavior of synthesized methods + +Here, we make sure that the synthesized methods of dataclasses behave like proper functions. + +```toml +[environment] +python-version = "3.12" +``` + +```py +from dataclasses import dataclass +from typing import Callable +from types import FunctionType +from ty_extensions import CallableTypeOf, TypeOf, static_assert, is_subtype_of, is_assignable_to + +@dataclass +class C: + x: int + +reveal_type(C.__init__) # revealed: (self: C, x: int) -> None +reveal_type(type(C.__init__)) # revealed: + +# We can access attributes that are defined on functions: +reveal_type(type(C.__init__).__code__) # revealed: CodeType +reveal_type(C.__init__.__code__) # revealed: CodeType + +def equivalent_signature(self: C, x: int) -> None: + pass + +type DunderInitType = TypeOf[C.__init__] +type EquivalentPureCallableType = Callable[[C, int], None] +type EquivalentFunctionLikeCallableType = CallableTypeOf[equivalent_signature] + +static_assert(is_subtype_of(DunderInitType, EquivalentPureCallableType)) +static_assert(is_assignable_to(DunderInitType, EquivalentPureCallableType)) + +static_assert(not is_subtype_of(EquivalentPureCallableType, DunderInitType)) +static_assert(not is_assignable_to(EquivalentPureCallableType, DunderInitType)) + +static_assert(is_subtype_of(DunderInitType, EquivalentFunctionLikeCallableType)) +static_assert(is_assignable_to(DunderInitType, EquivalentFunctionLikeCallableType)) + +static_assert(not is_subtype_of(EquivalentFunctionLikeCallableType, DunderInitType)) +static_assert(not is_assignable_to(EquivalentFunctionLikeCallableType, DunderInitType)) + +static_assert(is_subtype_of(DunderInitType, FunctionType)) +``` diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index c21144b719c754..af0ce426c3dd6c 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1258,6 +1258,14 @@ impl<'db> Type<'db> { ) => (self.literal_fallback_instance(db)) .is_some_and(|instance| instance.is_subtype_of(db, target)), + // Function-like callables are subtypes of `FunctionType` + (Type::Callable(callable), Type::NominalInstance(target)) + if callable.is_function_like(db) + && target.class.is_known(db, KnownClass::FunctionType) => + { + true + } + (Type::FunctionLiteral(self_function_literal), Type::Callable(_)) => { self_function_literal .into_callable_type(db) @@ -2766,6 +2774,26 @@ impl<'db> Type<'db> { instance.display(db), owner.display(db) ); + match self { + Type::Callable(callable) if callable.is_function_like(db) => { + // For "function-like" callables, model the the behavior of `FunctionType.__get__`. + // + // It is a shortcut to model this in `try_call_dunder_get`. If we want to be really precise, + // we should instead return a new method-wrapper type variant for the synthesized `__get__` + // method of these synthesized functions. The method-wrapper would then be returned from + // `find_name_in_mro` when called on function-like `Callable`s. This would allow us to + // correctly model the behavior of *explicit* `SomeDataclass.__init__.__get__` calls. + return if instance.is_none(db) { + Some((self, AttributeKind::NormalOrNonDataDescriptor)) + } else { + Some(( + Type::Callable(callable.bind_self(db)), + AttributeKind::NormalOrNonDataDescriptor, + )) + }; + } + _ => {} + } let descr_get = self.class_member(db, "__get__".into()).symbol; @@ -3099,6 +3127,11 @@ impl<'db> Type<'db> { Type::Callable(_) | Type::DataclassTransformer(_) if name_str == "__call__" => { Symbol::bound(self).into() } + + Type::Callable(callable) if callable.is_function_like(db) => KnownClass::FunctionType + .to_instance(db) + .member_lookup_with_policy(db, name, policy), + Type::Callable(_) | Type::DataclassTransformer(_) => KnownClass::Object .to_instance(db) .member_lookup_with_policy(db, name, policy), @@ -5127,6 +5160,9 @@ impl<'db> Type<'db> { Type::MethodWrapper(_) => KnownClass::MethodWrapperType.to_class_literal(db), Type::WrapperDescriptor(_) => KnownClass::WrapperDescriptorType.to_class_literal(db), Type::DataclassDecorator(_) => KnownClass::FunctionType.to_class_literal(db), + Type::Callable(callable) if callable.is_function_like(db) => { + KnownClass::FunctionType.to_class_literal(db) + } Type::Callable(_) | Type::DataclassTransformer(_) => KnownClass::Type.to_instance(db), Type::ModuleLiteral(_) => KnownClass::ModuleType.to_class_literal(db), Type::Tuple(_) => KnownClass::Tuple.to_class_literal(db), @@ -6905,6 +6941,7 @@ impl<'db> FunctionType<'db> { Type::Callable(CallableType::from_overloads( db, self.signature(db).overloads.iter().cloned(), + false, )) } @@ -7531,6 +7568,7 @@ impl<'db> BoundMethodType<'db> { .overloads .iter() .map(signatures::Signature::bind_self), + false, )) } @@ -7595,12 +7633,23 @@ impl<'db> BoundMethodType<'db> { pub struct CallableType<'db> { #[returns(deref)] signatures: Box<[Signature<'db>]>, + /// We use `CallableType` to represent function-like objects, like the synthesized methods + /// of dataclasses or NamedTuples. These callables act like real functions when accessed + /// as attributes on instances, i.e. they bind `self`. + is_function_like: bool, } impl<'db> CallableType<'db> { /// Create a non-overloaded callable type with a single signature. pub(crate) fn single(db: &'db dyn Db, signature: Signature<'db>) -> Self { - CallableType::new(db, vec![signature].into_boxed_slice()) + CallableType::new(db, vec![signature].into_boxed_slice(), false) + } + + /// Create a non-overloaded, function-like callable type with a single signature. + /// + /// A function-like callable will bind `self` when accessed as an attribute on an instance. + pub(crate) fn function_like(db: &'db dyn Db, signature: Signature<'db>) -> Self { + CallableType::new(db, vec![signature].into_boxed_slice(), true) } /// Create an overloaded callable type with multiple signatures. @@ -7608,7 +7657,7 @@ impl<'db> CallableType<'db> { /// # Panics /// /// Panics if `overloads` is empty. - pub(crate) fn from_overloads(db: &'db dyn Db, overloads: I) -> Self + pub(crate) fn from_overloads(db: &'db dyn Db, overloads: I, is_function_like: bool) -> Self where I: IntoIterator>, { @@ -7617,7 +7666,7 @@ impl<'db> CallableType<'db> { !overloads.is_empty(), "CallableType must have at least one signature" ); - CallableType::new(db, overloads) + CallableType::new(db, overloads, is_function_like) } /// Create a callable type which accepts any parameters and returns an `Unknown` type. @@ -7628,6 +7677,14 @@ impl<'db> CallableType<'db> { ) } + pub(crate) fn bind_self(self, db: &'db dyn Db) -> Self { + CallableType::from_overloads( + db, + self.signatures(db).iter().map(Signature::bind_self), + false, + ) + } + /// Create a callable type which represents a fully-static "bottom" callable. /// /// Specifically, this represents a callable type with a single signature: @@ -7646,6 +7703,7 @@ impl<'db> CallableType<'db> { self.signatures(db) .iter() .map(|signature| signature.normalized(db)), + self.is_function_like(db), ) } @@ -7655,6 +7713,7 @@ impl<'db> CallableType<'db> { self.signatures(db) .iter() .map(|signature| signature.apply_type_mapping(db, type_mapping)), + self.is_function_like(db), ) } @@ -7703,6 +7762,13 @@ impl<'db> CallableType<'db> { where F: Fn(&Signature<'db>, &Signature<'db>) -> bool, { + let self_is_function_like = self.is_function_like(db); + let other_is_function_like = other.is_function_like(db); + + if !self_is_function_like && other_is_function_like { + return false; + } + match (self.signatures(db), other.signatures(db)) { ([self_signature], [other_signature]) => { // Base case: both callable types contain a single signature. @@ -7745,6 +7811,10 @@ impl<'db> CallableType<'db> { /// /// See [`Type::is_equivalent_to`] for more details. fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool { + if self.is_function_like(db) != other.is_function_like(db) { + return false; + } + match (self.signatures(db), other.signatures(db)) { ([self_signature], [other_signature]) => { // Common case: both callable types contain a single signature, use the custom @@ -7771,6 +7841,10 @@ impl<'db> CallableType<'db> { /// /// See [`Type::is_gradual_equivalent_to`] for more details. fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool { + if self.is_function_like(db) != other.is_function_like(db) { + return false; + } + match (self.signatures(db), other.signatures(db)) { ([self_signature], [other_signature]) => { self_signature.is_gradual_equivalent_to(db, other_signature) @@ -7790,6 +7864,7 @@ impl<'db> CallableType<'db> { .iter() .cloned() .map(|signature| signature.replace_self_reference(db, class)), + self.is_function_like(db), ) } } diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 83870384ee971f..cc482bf0b92fb9 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -1260,7 +1260,7 @@ impl<'db> ClassLiteral<'db> { } let signature = Signature::new(Parameters::new(parameters), Some(Type::none(db))); - Some(Type::Callable(CallableType::single(db, signature))) + Some(Type::Callable(CallableType::function_like(db, signature))) }; match (field_policy, name) { @@ -1274,7 +1274,13 @@ impl<'db> ClassLiteral<'db> { return None; } - signature_from_fields(vec![]) + let self_parameter = Parameter::positional_or_keyword(Name::new_static("self")) + // TODO: could be `Self`. + .with_annotated_type(Type::instance( + db, + self.apply_optional_specialization(db, specialization), + )); + signature_from_fields(vec![self_parameter]) } (CodeGeneratorKind::NamedTuple, "__new__") => { let cls_parameter = Parameter::positional_or_keyword(Name::new_static("cls")) @@ -1287,16 +1293,24 @@ impl<'db> ClassLiteral<'db> { } let signature = Signature::new( - Parameters::new([Parameter::positional_or_keyword(Name::new_static("other")) - // TODO: could be `Self`. - .with_annotated_type(Type::instance( - db, - self.apply_optional_specialization(db, specialization), - ))]), + Parameters::new([ + Parameter::positional_or_keyword(Name::new_static("self")) + // TODO: could be `Self`. + .with_annotated_type(Type::instance( + db, + self.apply_optional_specialization(db, specialization), + )), + Parameter::positional_or_keyword(Name::new_static("other")) + // TODO: could be `Self`. + .with_annotated_type(Type::instance( + db, + self.apply_optional_specialization(db, specialization), + )), + ]), Some(KnownClass::Bool.to_instance(db)), ); - Some(Type::Callable(CallableType::single(db, signature))) + Some(Type::Callable(CallableType::function_like(db, signature))) } (CodeGeneratorKind::NamedTuple, name) if name != "__init__" => { KnownClass::NamedTupleFallback diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 9f3f48974525d8..23f7c8adac1844 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -8712,6 +8712,7 @@ impl<'db> TypeInferenceBuilder<'db> { Type::Callable(CallableType::from_overloads( db, std::iter::once(signature).chain(signature_iter), + false, )) } },