Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions crates/ty_python_semantic/resources/mdtest/bidirectional.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,84 @@ def h[T](x: T, cond: bool) -> T | list[T]:
def i[T](x: T, cond: bool) -> T | list[T]:
return x if cond else [x]
```

## Type context sources

Type context is sourced from various places, including annotated assignments:

```py
from typing import Literal

a: list[Literal[1]] = [1]
```

Function parameter annotations:

```py
def b(x: list[Literal[1]]): ...

b([1])
```

Bound method parameter annotations:

```py
class C:
def __init__(self, x: list[Literal[1]]): ...
def foo(self, x: list[Literal[1]]): ...

C([1]).foo([1])
```

Declared variable types:

```py
d: list[Literal[1]]
d = [1]
```

Declared attribute types:

```py
class E:
e: list[Literal[1]]

def _(e: E):
# TODO: Implement attribute type context.
# error: [invalid-assignment] "Object of type `list[Unknown | int]` is not assignable to attribute `e` of type `list[Literal[1]]`"
e.e = [1]
```

Function return types:

```py
def f() -> list[Literal[1]]:
return [1]
```

## Class constructor parameters

```toml
[environment]
python-version = "3.12"
```

The parameters of both `__init__` and `__new__` are used as type context sources for constructor
calls:

```py
def f[T](x: T) -> list[T]:
return [x]

class A:
def __new__(cls, value: list[int | str]):
return super().__new__(cls, value)

def __init__(self, value: list[int | None]): ...

A(f(1))

# error: [invalid-argument-type] "Argument to function `__new__` is incorrect: Expected `list[int | str]`, found `list[list[Unknown]]`"
# error: [invalid-argument-type] "Argument to bound method `__init__` is incorrect: Expected `list[int | None]`, found `list[list[Unknown]]`"
A(f([]))
```
78 changes: 58 additions & 20 deletions crates/ty_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5961,6 +5961,9 @@ impl<'db> Type<'db> {
/// Given a class literal or non-dynamic `SubclassOf` type, try calling it (creating an instance)
/// and return the resulting instance type.
///
/// The `infer_argument_types` closure should be invoked with the signatures of `__new__` and
/// `__init__`, such that the argument types can be inferred with the correct type context.
///
/// Models `type.__call__` behavior.
/// TODO: model metaclass `__call__`.
///
Expand All @@ -5971,10 +5974,10 @@ impl<'db> Type<'db> {
///
/// Foo()
/// ```
fn try_call_constructor(
fn try_call_constructor<'ast>(
self,
db: &'db dyn Db,
argument_types: CallArguments<'_, 'db>,
infer_argument_types: impl FnOnce(Option<Bindings<'db>>) -> CallArguments<'ast, 'db>,
tcx: TypeContext<'db>,
) -> Result<Type<'db>, ConstructorCallError<'db>> {
debug_assert!(matches!(
Expand Down Expand Up @@ -6030,11 +6033,63 @@ impl<'db> Type<'db> {
// easy to check if that's the one we found?
// Note that `__new__` is a static method, so we must inject the `cls` argument.
let new_method = self_type.lookup_dunder_new(db, ());

// Construct an instance type that we can use to look up the `__init__` instance method.
// This performs the same logic as `Type::to_instance`, except for generic class literals.
// TODO: we should use the actual return type of `__new__` to determine the instance type
let init_ty = self_type
.to_instance(db)
.expect("type should be convertible to instance type");

// Lookup the `__init__` instance method in the MRO.
let init_method = init_ty.member_lookup_with_policy(
db,
"__init__".into(),
MemberLookupPolicy::NO_INSTANCE_FALLBACK | MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK,
);

// Infer the call argument types, using both `__new__` and `__init__` for type-context.
let bindings = match (
new_method.as_ref().map(|method| &method.place),
&init_method.place,
) {
(Some(Place::Defined(new_method, ..)), Place::Undefined) => Some(
new_method
.bindings(db)
.map(|binding| binding.with_bound_type(self_type)),
),

(Some(Place::Undefined) | None, Place::Defined(init_method, ..)) => {
Some(init_method.bindings(db))
}

(Some(Place::Defined(new_method, ..)), Place::Defined(init_method, ..)) => {
let callable = UnionBuilder::new(db)
.add(*new_method)
.add(*init_method)
.build();

let new_method_bindings = new_method
.bindings(db)
.map(|binding| binding.with_bound_type(self_type));

Some(Bindings::from_union(
callable,
[new_method_bindings, init_method.bindings(db)],
))
}

_ => None,
};

let argument_types = infer_argument_types(bindings);

let new_call_outcome = new_method.and_then(|new_method| {
match new_method.place.try_call_dunder_get(db, self_type) {
Place::Defined(new_method, _, boundness) => {
let result =
new_method.try_call(db, argument_types.with_self(Some(self_type)).as_ref());

if boundness == Definedness::PossiblyUndefined {
Some(Err(DunderNewCallError::PossiblyUnbound(result.err())))
} else {
Expand All @@ -6045,24 +6100,7 @@ impl<'db> Type<'db> {
}
});

// Construct an instance type that we can use to look up the `__init__` instance method.
// This performs the same logic as `Type::to_instance`, except for generic class literals.
// TODO: we should use the actual return type of `__new__` to determine the instance type
let init_ty = self_type
.to_instance(db)
.expect("type should be convertible to instance type");

let init_call_outcome = if new_call_outcome.is_none()
|| !init_ty
.member_lookup_with_policy(
db,
"__init__".into(),
MemberLookupPolicy::NO_INSTANCE_FALLBACK
| MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK,
)
.place
.is_undefined()
{
let init_call_outcome = if new_call_outcome.is_none() || !init_method.is_undefined() {
Some(init_ty.try_call_dunder(db, "__init__", argument_types, tcx))
} else {
None
Expand Down
8 changes: 8 additions & 0 deletions crates/ty_python_semantic/src/types/call/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ impl<'db> Bindings<'db> {
self.elements.iter()
}

pub(crate) fn map(self, f: impl Fn(CallableBinding<'db>) -> CallableBinding<'db>) -> Self {
Self {
callable_type: self.callable_type,
argument_forms: self.argument_forms,
elements: self.elements.into_iter().map(f).collect(),
}
}

/// Match the arguments of a call site against the parameters of a collection of possibly
/// unioned, possibly overloaded signatures.
///
Expand Down
18 changes: 14 additions & 4 deletions crates/ty_python_semantic/src/types/infer/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6798,9 +6798,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.to_class_type(self.db())
.is_none_or(|enum_class| !class.is_subclass_of(self.db(), enum_class))
{
let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()];
self.infer_argument_types(arguments, &mut call_arguments, &argument_forms);

if matches!(
class.known(self.db()),
Some(KnownClass::TypeVar | KnownClass::ExtensionsTypeVar)
Expand All @@ -6819,8 +6816,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}

let db = self.db();
let infer_call_arguments = |bindings: Option<Bindings<'db>>| {
if let Some(bindings) = bindings {
let bindings = bindings.match_parameters(self.db(), &call_arguments);
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings);
} else {
let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()];
self.infer_argument_types(arguments, &mut call_arguments, &argument_forms);
}

call_arguments
};

return callable_type
.try_call_constructor(self.db(), call_arguments, tcx)
.try_call_constructor(db, infer_call_arguments, tcx)
.unwrap_or_else(|err| {
err.report_diagnostic(&self.context, callable_type, call_expression.into());
err.return_type()
Expand Down
Loading