Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion crates/ty_python_semantic/resources/mdtest/call/overloads.md
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,7 @@ def _(t: tuple[int, str] | tuple[int, str, int]) -> None:
f(*t) # error: [no-matching-overload]
```

## Filtering based on variaidic arguments
## Filtering based on variadic arguments

This is step 4 of the overload call evaluation algorithm which specifies that:

Expand Down Expand Up @@ -1469,6 +1469,79 @@ def _(arg: list[Any]):
reveal_type(f4(*arg)) # revealed: Unknown
```

### Varidic argument with generics

`overloaded.pyi`:

```pyi
from typing import Any, TypeVar, overload

T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")

@overload
def f1(x: T1, /) -> tuple[T1]: ...
@overload
def f1(x1: T1, x2: T2, /) -> tuple[T1, T2]: ...
@overload
def f1(x1: T1, x2: T2, x3: T3, /) -> tuple[T1, T2, T3]: ...
@overload
def f1(*args: Any) -> tuple[Any, ...]: ...

@overload
def f2(x1: T1) -> tuple[T1]: ...
@overload
def f2(x1: T1, x2: T2) -> tuple[T1, T2]: ...
@overload
def f2(*args: Any, **kwargs: Any) -> tuple[Any, ...]: ...

@overload
def f3(x: T1) -> tuple[T1]: ...
@overload
def f3(x1: T1, x2: T2) -> tuple[T1, T2]: ...
@overload
def f3(*args: Any) -> tuple[Any, ...]: ...
@overload
def f3(**kwargs: Any) -> dict[str, Any]: ...
```

```py
from overloaded import f1, f2, f3
from typing import Any

# These calls only match the last overload
reveal_type(f1()) # revealed: tuple[Any, ...]
reveal_type(f1(1, 2, 3, 4)) # revealed: tuple[Any, ...]

# While these calls match multiple overloads but step 5 filters out all the remaining overloads
# except the most specific one in terms of the number of arguments.
reveal_type(f1(1)) # revealed: tuple[Literal[1]]
reveal_type(f1(1, 2)) # revealed: tuple[Literal[1], Literal[2]]
reveal_type(f1(1, 2, 3)) # revealed: tuple[Literal[1], Literal[2], Literal[3]]

def _(args1: list[int], args2: list[Any]):
reveal_type(f1(*args1)) # revealed: tuple[Any, ...]
reveal_type(f1(*args2)) # revealed: tuple[Any, ...]

reveal_type(f2()) # revealed: tuple[Any, ...]
reveal_type(f2(1, 2)) # revealed: tuple[Literal[1], Literal[2]]
# TODO: Should be `tuple[Literal[1], Literal[2]]`
reveal_type(f2(x1=1, x2=2)) # revealed: Unknown
# TODO: Should be `tuple[Literal[2], Literal[1]]`
reveal_type(f2(x2=1, x1=2)) # revealed: Unknown
reveal_type(f2(1, 2, z=3)) # revealed: tuple[Any, ...]

reveal_type(f3(1, 2)) # revealed: tuple[Literal[1], Literal[2]]
reveal_type(f3(1, 2, 3)) # revealed: tuple[Any, ...]
# TODO: Should be `tuple[Literal[1], Literal[2]]`
reveal_type(f3(x1=1, x2=2)) # revealed: Unknown
reveal_type(f3(z=1)) # revealed: dict[str, Any]

# error: [no-matching-overload]
reveal_type(f3(1, 2, x=3)) # revealed: Unknown
```

### Non-participating fully-static parameter

Ref: <https://github.com/astral-sh/ty/issues/552#issuecomment-2969052173>
Expand Down
2 changes: 1 addition & 1 deletion crates/ty_python_semantic/src/types/call/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ impl<'a, 'db> CallArguments<'a, 'db> {
if expansion_size > MAX_EXPANSIONS {
tracing::debug!(
"Skipping argument type expansion as it would exceed the \
maximum number of expansions ({MAX_EXPANSIONS})"
maximum number of expansions ({MAX_EXPANSIONS})"
);
return Some(State::LimitReached(index));
}
Expand Down
89 changes: 89 additions & 0 deletions crates/ty_python_semantic/src/types/call/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
//! arguments against the parameters of the callable. Like with
//! [signatures][crate::types::signatures], we have to handle the fact that the callable might be a
//! union of types, each of which might contain multiple overloads.
//!
//! ### Tracing
//!
//! This module is instrumented with debug-level `tracing` messages. You can set the `TY_LOG`
//! environment variable to see this output when testing locally. `tracing` log messages typically
//! have a `target` field, which is the name of the module the message appears in — in this case,
//! `ty_python_semantic::types::call::bind`.

use std::borrow::Cow;
use std::collections::HashSet;
Expand Down Expand Up @@ -1582,6 +1589,13 @@ impl<'db> CallableBinding<'db> {
// before checking.
let argument_types = argument_types.with_self(self.bound_type);

tracing::debug!(
target: "ty_python_semantic::types::call::bind",
matching_overload_index = ?self.matching_overload_index(),
signature = %self.signature_type.display(db),
"after step 1",
);
Comment on lines +1592 to +1597
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The signature is important because we want to know which message is this for but this will be too verbose for functions containing multiple overloads. Maybe, it might be better to add some information about the call site instead like which call expression this is for.


// Step 1: Check the result of the arity check which is done by `match_parameters`
let matching_overload_indexes = match self.matching_overload_index() {
MatchingOverloadIndex::None => {
Expand Down Expand Up @@ -1612,6 +1626,13 @@ impl<'db> CallableBinding<'db> {
overload.check_types(db, argument_types.as_ref(), call_expression_tcx);
}

tracing::debug!(
target: "ty_python_semantic::types::call::bind",
matching_overload_index = ?self.matching_overload_index(),
signature = %self.signature_type.display(db),
"after step 2",
);

match self.matching_overload_index() {
MatchingOverloadIndex::None => {
// If all overloads result in errors, proceed to step 3.
Expand All @@ -1624,6 +1645,13 @@ impl<'db> CallableBinding<'db> {
// If two or more candidate overloads remain, proceed to step 4.
self.filter_overloads_containing_variadic(&indexes);

tracing::debug!(
target: "ty_python_semantic::types::call::bind",
matching_overload_index = ?self.matching_overload_index(),
signature = %self.signature_type.display(db),
"after step 4",
);

match self.matching_overload_index() {
MatchingOverloadIndex::None => {
// This shouldn't be possible because step 4 can only filter out overloads
Expand All @@ -1642,6 +1670,13 @@ impl<'db> CallableBinding<'db> {
argument_types.as_ref(),
&indexes,
);

tracing::debug!(
target: "ty_python_semantic::types::call::bind",
matching_overload_index = ?self.matching_overload_index(),
signature = %self.signature_type.display(db),
"after step 5",
);
}
}

Expand Down Expand Up @@ -1744,12 +1779,26 @@ impl<'db> CallableBinding<'db> {
overload.match_parameters(db, expanded_arguments, &mut argument_forms);
}

tracing::debug!(
target: "ty_python_semantic::types::call::bind",
matching_overload_index = ?self.matching_overload_index(),
signature = %self.signature_type.display(db),
"after step 1",
);

merged_argument_forms.merge(&argument_forms);

for (_, overload) in self.matching_overloads_mut() {
overload.check_types(db, expanded_arguments, call_expression_tcx);
}

tracing::debug!(
target: "ty_python_semantic::types::call::bind",
matching_overload_index = ?self.matching_overload_index(),
signature = %self.signature_type.display(db),
"after step 2",
);

let return_type = match self.matching_overload_index() {
MatchingOverloadIndex::None => None,
MatchingOverloadIndex::Single(index) => {
Expand All @@ -1758,6 +1807,13 @@ impl<'db> CallableBinding<'db> {
MatchingOverloadIndex::Multiple(matching_overload_indexes) => {
self.filter_overloads_containing_variadic(&matching_overload_indexes);

tracing::debug!(
target: "ty_python_semantic::types::call::bind",
matching_overload_index = ?self.matching_overload_index(),
signature = %self.signature_type.display(db),
"after step 4",
);

match self.matching_overload_index() {
MatchingOverloadIndex::None => {
tracing::debug!(
Expand All @@ -1772,6 +1828,14 @@ impl<'db> CallableBinding<'db> {
expanded_arguments,
&indexes,
);

tracing::debug!(
target: "ty_python_semantic::types::call::bind",
matching_overload_index = ?self.matching_overload_index(),
signature = %self.signature_type.display(db),
"after step 5",
);

Some(self.return_type())
}
}
Expand Down Expand Up @@ -1926,12 +1990,37 @@ impl<'db> CallableBinding<'db> {
.take(max_parameter_count)
.collect::<Vec<_>>();

// The following loop is trying to construct a tuple of argument types that correspond to
// the participating parameter indexes. Considering the following example:
//
// ```python
// @overload
// def f(x: Literal[1], y: Literal[2]) -> tuple[int, int]: ...
// @overload
// def f(*args: Any) -> tuple[Any, ...]: ...
//
// f(1, 2)
// ```
//
// Here, only the first parameter participates in the filtering process because only one
// overload has the second parameter. So, while going through the argument types, the
// second argument needs to be skipped but for the second overload both arguments map to
// the first parameter and that parameter is considered for the filtering process. This
// flag is to handle that special case of many-to-one mapping from arguments to parameters.
let mut variadic_parameter_handled = false;

for (argument_index, argument_type) in arguments.iter_types().enumerate() {
if variadic_parameter_handled {
continue;
}
for overload_index in matching_overload_indexes {
let overload = &self.overloads[*overload_index];
for (parameter_index, variadic_argument_type) in
overload.argument_matches[argument_index].iter()
{
if overload.signature.parameters()[parameter_index].is_variadic() {
variadic_parameter_handled = true;
}
if !participating_parameter_indexes.contains(&parameter_index) {
continue;
}
Expand Down
Loading