Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions crates/ty_python_semantic/resources/mdtest/subscript/tuple.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

## Indexing

```toml
[environment]
python-version = "3.11"
```

```py
t = (1, "a", "b")

Expand All @@ -20,6 +25,148 @@ b = t[-4] # error: [index-out-of-bounds]
reveal_type(b) # revealed: Unknown
```

Precise types for index operations are also inferred for tuple subclasses:

```py
class I0: ...
class I1: ...
class I2: ...
class I3: ...
class I5: ...
class HeterogeneousSubclass0(tuple[()]): ...

# revealed: Overload[(self, index: SupportsIndex, /) -> Never, (self, index: slice[Any, Any, Any], /) -> tuple[()]]
reveal_type(HeterogeneousSubclass0.__getitem__)

def f0(h0: HeterogeneousSubclass0, i: int):
reveal_type(h0[0]) # revealed: Never
reveal_type(h0[1]) # revealed: Never
reveal_type(h0[-1]) # revealed: Never
reveal_type(h0[i]) # revealed: Never

class HeterogeneousSubclass1(tuple[I0]): ...

# revealed: Overload[(self, index: SupportsIndex, /) -> I0, (self, index: slice[Any, Any, Any], /) -> tuple[I0, ...]]
reveal_type(HeterogeneousSubclass1.__getitem__)

def f0(h1: HeterogeneousSubclass1, i: int):
reveal_type(h1[0]) # revealed: I0
reveal_type(h1[1]) # revealed: I0
reveal_type(h1[-1]) # revealed: I0
reveal_type(h1[i]) # revealed: I0

# Element at index 2 is deliberately the same as the element at index 1,
# to illustrate that the `__getitem__` overloads for these two indices are combined
class HeterogeneousSubclass4(tuple[I0, I1, I0, I3]): ...

# revealed: Overload[(self, index: Literal[-4, -2, 0, 2], /) -> I0, (self, index: Literal[-3, 1], /) -> I1, (self, index: Literal[-1, 3], /) -> I3, (self, index: SupportsIndex, /) -> I0 | I1 | I3, (self, index: slice[Any, Any, Any], /) -> tuple[I0 | I1 | I3, ...]]
reveal_type(HeterogeneousSubclass4.__getitem__)

def f(h4: HeterogeneousSubclass4, i: int):
reveal_type(h4[0]) # revealed: I0
reveal_type(h4[1]) # revealed: I1
reveal_type(h4[2]) # revealed: I0
reveal_type(h4[3]) # revealed: I3
reveal_type(h4[-1]) # revealed: I3
reveal_type(h4[-2]) # revealed: I0
reveal_type(h4[-3]) # revealed: I1
reveal_type(h4[-4]) # revealed: I0
reveal_type(h4[i]) # revealed: I0 | I1 | I3

class MixedSubclass(tuple[I0, *tuple[I1, ...], I2, I3, I2, I5]): ...

# revealed: Overload[(self, index: Literal[0], /) -> I0, (self, index: Literal[2, 3], /) -> I1 | I2 | I3, (self, index: Literal[-1], /) -> I5, (self, index: Literal[1], /) -> I1 | I2, (self, index: Literal[-3], /) -> I3, (self, index: Literal[-5], /) -> I1 | I0, (self, index: Literal[-4, -2], /) -> I2, (self, index: Literal[4], /) -> I1 | I2 | I3 | I5, (self, index: SupportsIndex, /) -> I0 | I1 | I2 | I3 | I5, (self, index: slice[Any, Any, Any], /) -> tuple[I0 | I1 | I2 | I3 | I5, ...]]
reveal_type(MixedSubclass.__getitem__)

def g(m: MixedSubclass, i: int):
reveal_type(m[0]) # revealed: I0
reveal_type(m[1]) # revealed: I1 | I2
reveal_type(m[2]) # revealed: I1 | I2 | I3
reveal_type(m[3]) # revealed: I1 | I2 | I3
reveal_type(m[4]) # revealed: I1 | I2 | I3 | I5

reveal_type(m[-1]) # revealed: I5
reveal_type(m[-2]) # revealed: I2
reveal_type(m[-3]) # revealed: I3
reveal_type(m[-4]) # revealed: I2
reveal_type(m[-5]) # revealed: I1 | I0

reveal_type(m[i]) # revealed: I0 | I1 | I2 | I3 | I5

# Ideally we would not include `I0` in the unions for these,
# but it's not possible to do this using only synthesized overloads.
reveal_type(m[5]) # revealed: I0 | I1 | I2 | I3 | I5
reveal_type(m[10]) # revealed: I0 | I1 | I2 | I3 | I5

# Similarly, ideally these would just be `I0` | I1`,
# but achieving that with only synthesized overloads wouldn't be possible
reveal_type(m[-6]) # revealed: I0 | I1 | I2 | I3 | I5
reveal_type(m[-10]) # revealed: I0 | I1 | I2 | I3 | I5

class MixedSubclass2(tuple[I0, I1, *tuple[I2, ...], I3]): ...

# revealed: Overload[(self, index: Literal[-1], /) -> I3, (self, index: Literal[0], /) -> I0, (self, index: Literal[-2], /) -> I2 | I1, (self, index: Literal[2], /) -> I2 | I3, (self, index: Literal[1], /) -> I1, (self, index: Literal[-3], /) -> I2 | I1 | I0, (self, index: SupportsIndex, /) -> I0 | I1 | I2 | I3, (self, index: slice[Any, Any, Any], /) -> tuple[I0 | I1 | I2 | I3, ...]]
reveal_type(MixedSubclass2.__getitem__)

def g(m: MixedSubclass2, i: int):
reveal_type(m[0]) # revealed: I0
reveal_type(m[1]) # revealed: I1
reveal_type(m[2]) # revealed: I2 | I3

# Ideally this would just be `I2 | I3`,
# but that's not possible to achieve with synthesized overloads
reveal_type(m[3]) # revealed: I0 | I1 | I2 | I3

reveal_type(m[-1]) # revealed: I3
reveal_type(m[-2]) # revealed: I2 | I1
reveal_type(m[-3]) # revealed: I2 | I1 | I0

# Ideally this would just be `I2 | I1 | I0`,
# but that's not possible to achieve with synthesized overloads
reveal_type(m[-4]) # revealed: I0 | I1 | I2 | I3
```

The stdlib API `os.stat` is a commonly used API that returns an instance of a tuple subclass
(`os.stat_result`), and therefore provides a good integration test for tuple subclasses.

```py
import os
import stat

reveal_type(os.stat("my_file.txt")) # revealed: stat_result
reveal_type(os.stat("my_file.txt")[stat.ST_MODE]) # revealed: int
reveal_type(os.stat("my_file.txt")[stat.ST_ATIME]) # revealed: int | float

# revealed: tuple[<class 'stat_result'>, <class 'structseq[int | float]'>, <class 'tuple[int, int, int, int, int, int, int, int | float, int | float, int | float]'>, <class 'Sequence[int | float]'>, <class 'Reversible[int | float]'>, <class 'Collection[int | float]'>, <class 'Iterable[int | float]'>, <class 'Container[int | float]'>, typing.Protocol, typing.Generic, <class 'object'>]
reveal_type(os.stat_result.__mro__)

# There are no specific overloads for the `float` elements in `os.stat_result`,
# because the fallback `(self, index: SupportsIndex, /) -> int | float` overload
# gives the right result for those elements in the tuple, and we aim to synthesize
# the minimum number of overloads for any given tuple
#
# revealed: Overload[(self, index: Literal[-10, -9, -8, -7, -6, -5, -4, 0, 1, 2, 3, 4, 5, 6], /) -> int, (self, index: SupportsIndex, /) -> int | float, (self, index: slice[Any, Any, Any], /) -> tuple[int | float, ...]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stat_result is based on tuple[int, int, int, int, int, int, int, float, float, float], so why don't we see the special overload for float here? Because float splits into int | float? And int | float is also equal to the union of all element types? Makes sense... but maybe worth a comment? Or is there another stdlib API that we could use as a more interesting example?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think stat_result is a good one to test here, as there are several hits in the mypy_primer diff showing that people often index into it with an index <=6 and expect the type checker to understand that the result is an int. But I'll add a comment -- and I can add some tests with "more heterogeneous" stdlib APIs too!

reveal_type(os.stat_result.__getitem__)
```

Because of the synthesized `__getitem__` overloads we synthesize for tuples and tuple subclasses,
tuples are naturally understood as being subtypes of protocols that have precise return types from
`__getitem__` method members:

```py
from typing import Protocol, Literal
from ty_extensions import static_assert, is_subtype_of

class IntFromZeroSubscript(Protocol):
def __getitem__(self, index: Literal[0], /) -> int: ...

static_assert(is_subtype_of(tuple[int, str], IntFromZeroSubscript))

class TupleSubclass(tuple[int, str]): ...

static_assert(is_subtype_of(TupleSubclass, IntFromZeroSubscript))
```

## Slices

```py
Expand Down
181 changes: 176 additions & 5 deletions crates/ty_python_semantic/src/types/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::types::function::{DataclassTransformerParams, KnownFunction};
use crate::types::generics::{GenericContext, Specialization, walk_specialization};
use crate::types::infer::nearest_enclosing_class;
use crate::types::signatures::{CallableSignature, Parameter, Parameters, Signature};
use crate::types::tuple::TupleType;
use crate::types::tuple::{TupleSpec, TupleType};
use crate::types::{
BareTypeAliasType, Binding, BoundSuperError, BoundSuperType, CallableType, DataclassParams,
DeprecatedInstance, DynamicType, KnownInstanceType, TypeAliasType, TypeMapping, TypeRelation,
Expand Down Expand Up @@ -53,7 +53,7 @@ use ruff_db::parsed::{ParsedModuleRef, parsed_module};
use ruff_python_ast::name::Name;
use ruff_python_ast::{self as ast, PythonVersion};
use ruff_text_size::{Ranged, TextRange};
use rustc_hash::{FxHashSet, FxHasher};
use rustc_hash::{FxHashMap, FxHashSet, FxHasher};

type FxOrderMap<K, V> = ordermap::map::OrderMap<K, V, BuildHasherDefault<FxHasher>>;

Expand Down Expand Up @@ -574,8 +574,25 @@ impl<'db> ClassType<'db> {
/// directly. Use [`ClassType::class_member`] if you require a method that will
/// traverse through the MRO until it finds the member.
pub(super) fn own_class_member(self, db: &'db dyn Db, name: &str) -> PlaceAndQualifiers<'db> {
fn synthesize_getitem_overload_signature<'db>(
index_annotation: Type<'db>,
return_annotation: Type<'db>,
) -> Signature<'db> {
let self_parameter = Parameter::positional_only(Some(Name::new_static("self")));
let index_parameter = Parameter::positional_only(Some(Name::new_static("index")))
.with_annotated_type(index_annotation);
let parameters = Parameters::new([self_parameter, index_parameter]);
Signature::new(parameters, Some(return_annotation))
}

let (class_literal, specialization) = self.class_literal(db);

let fallback_member_lookup = || {
class_literal
.own_class_member(db, specialization, name)
.map_type(|ty| ty.apply_optional_specialization(db, specialization))
};

let synthesize_simple_tuple_method = |return_type| {
let parameters =
Parameters::new([Parameter::positional_only(Some(Name::new_static("self")))
Expand Down Expand Up @@ -606,6 +623,162 @@ impl<'db> ClassType<'db> {
synthesize_simple_tuple_method(return_type)
}

"__getitem__" if class_literal.is_tuple(db) => {
specialization
.map(|spec| {
let tuple = spec.tuple(db);

let mut element_type_to_indices: FxHashMap<Type<'db>, Vec<i64>> =
FxHashMap::default();

match tuple {
// E.g. for `tuple[int, str]`, we will generate the following overloads:
//
// __getitem__(self, index: Literal[0, -2], /) -> int
// __getitem__(self, index: Literal[1, -1], /) -> str
//
TupleSpec::Fixed(fixed_length_tuple) => {
let tuple_length = fixed_length_tuple.len();

for (index, ty) in fixed_length_tuple.elements().enumerate() {
let entry = element_type_to_indices.entry(*ty).or_default();
if let Ok(index) = i64::try_from(index) {
entry.push(index);
}
if let Ok(index) = i64::try_from(tuple_length - index) {
entry.push(0 - index);
}
}
}

// E.g. for `tuple[str, *tuple[float, ...], bytes, range]`, we will generate the following overloads:
//
// __getitem__(self, index: Literal[0], /) -> str
// __getitem__(self, index: Literal[1], /) -> float | bytes
// __getitem__(self, index: Literal[2], /) -> float | bytes | range
// __getitem__(self, index: Literal[-1], /) -> range
// __getitem__(self, index: Literal[-2], /) -> bytes
// __getitem__(self, index: Literal[-3], /) -> float | str
//
TupleSpec::Variable(variable_length_tuple) => {
for (index, ty) in variable_length_tuple.prefix.iter().enumerate() {
if let Ok(index) = i64::try_from(index) {
element_type_to_indices.entry(*ty).or_default().push(index);
}

let one_based_index = index + 1;

if let Ok(i) = i64::try_from(
variable_length_tuple.suffix.len() + one_based_index,
) {
let overload_return = UnionType::from_elements(
db,
std::iter::once(variable_length_tuple.variable).chain(
variable_length_tuple
.prefix
.iter()
.rev()
.take(one_based_index)
.copied(),
),
);
element_type_to_indices
.entry(overload_return)
.or_default()
.push(0 - i);
}
}

for (index, ty) in
variable_length_tuple.suffix.iter().rev().enumerate()
{
if let Some(index) =
index.checked_add(1).and_then(|i| i64::try_from(i).ok())
{
element_type_to_indices
.entry(*ty)
.or_default()
.push(0 - index);
}

if let Ok(i) =
i64::try_from(variable_length_tuple.prefix.len() + index)
{
let overload_return = UnionType::from_elements(
db,
std::iter::once(variable_length_tuple.variable).chain(
variable_length_tuple
.suffix
.iter()
.take(index + 1)
.copied(),
),
);
element_type_to_indices
.entry(overload_return)
.or_default()
.push(i);
}
}
}
}

let all_elements_unioned =
UnionType::from_elements(db, tuple.all_elements());

let mut overload_signatures =
Vec::with_capacity(element_type_to_indices.len().saturating_add(2));

overload_signatures.extend(element_type_to_indices.into_iter().filter_map(
|(return_type, mut indices)| {
if return_type.is_equivalent_to(db, all_elements_unioned) {
return None;
}

// Sorting isn't strictly required, but leads to nicer `reveal_type` output
indices.sort_unstable();

let index_annotation = UnionType::from_elements(
db,
indices.into_iter().map(Type::IntLiteral),
);

Some(synthesize_getitem_overload_signature(
index_annotation,
return_type,
))
},
));

// Fallback overloads: for `tuple[int, str]`, we will generate the following overloads:
//
// __getitem__(self, index: int, /) -> int | str
// __getitem__(self, index: slice[Any, Any, Any], /) -> tuple[int | str, ...]
//
// and for `tuple[str, *tuple[float, ...], bytes]`, we will generate the following overloads:
//
// __getitem__(self, index: int, /) -> str | float | bytes
// __getitem__(self, index: slice[Any, Any, Any], /) -> tuple[str | float | bytes, ...]
//
overload_signatures.push(synthesize_getitem_overload_signature(
KnownClass::SupportsIndex.to_instance(db),
all_elements_unioned,
));

overload_signatures.push(synthesize_getitem_overload_signature(
KnownClass::Slice.to_instance(db),
TupleType::homogeneous(db, all_elements_unioned),
));

let getitem_signature =
CallableSignature::from_overloads(overload_signatures);
let getitem_type =
Type::Callable(CallableType::new(db, getitem_signature, true));
Place::bound(getitem_type).into()
})
.unwrap_or_else(fallback_member_lookup)
}

// ```py
// class tuple:
// @overload
Expand Down Expand Up @@ -672,9 +845,7 @@ impl<'db> ClassType<'db> {
Place::bound(synthesized_dunder).into()
}

_ => class_literal
.own_class_member(db, specialization, name)
.map_type(|ty| ty.apply_optional_specialization(db, specialization)),
_ => fallback_member_lookup(),
}
}

Expand Down
Loading