Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
name_1
{0: 0 for unique_name_0 in unique_name_1 if name_1}


@[name_2 for unique_name_2 in name_2]
def name_2():
pass


def name_2():
pass


match 0:
case name_2():
pass
case []:
name_1 = 0
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,92 @@ async def _():
# revealed: Unknown
[reveal_type(x) async for x in range(3)]
```

## Comprehension expression types

The type of the comprehension expression itself should reflect the inferred element type:

```py
from typing import TypedDict, Literal

# revealed: list[int | Unknown]
reveal_type([x for x in range(10)])

# revealed: set[int | Unknown]
reveal_type({x for x in range(10)})

# revealed: dict[int | Unknown, str | Unknown]
reveal_type({x: str(x) for x in range(10)})

# revealed: list[tuple[int, Unknown | str] | Unknown]
reveal_type([(x, y) for x in range(5) for y in ["a", "b", "c"]])

squares: list[int | None] = [x**2 for x in range(10)]
reveal_type(squares) # revealed: list[int | None]
```

Inference for comprehensions takes the type context into account:

```py
# Without type context:
reveal_type([x for x in [1, 2, 3]]) # revealed: list[Unknown | int]
reveal_type({x: "a" for x in [1, 2, 3]}) # revealed: dict[Unknown | int, str | Unknown]
reveal_type({str(x): x for x in [1, 2, 3]}) # revealed: dict[str | Unknown, Unknown | int]
reveal_type({x for x in [1, 2, 3]}) # revealed: set[Unknown | int]

# With type context:
xs: list[int] = [x for x in [1, 2, 3]]
reveal_type(xs) # revealed: list[int]

ys: dict[int, str] = {x: str(x) for x in [1, 2, 3]}
reveal_type(ys) # revealed: dict[int, str]

zs: set[int] = {x for x in [1, 2, 3]}
```

This also works for nested comprehensions:

```py
table = [[(x, y) for x in range(3)] for y in range(3)]
reveal_type(table) # revealed: list[list[tuple[int, int] | Unknown] | Unknown]

table_with_content: list[list[tuple[int, int, str | None]]] = [[(x, y, None) for x in range(3)] for y in range(3)]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed with @ibraheemdev, solving this will require pushing down the type context into infer_scope_types in order to infer the None in (x, y, None) as str | None inside the comprehension. Similar for the TypedDict test below. This change is a bit more invasive, as it requires us to skip comprehension scopes in check_types, because we would not have the required type context available when simply looping over all scopes. This change would be fine because with this PR, we're now calling infer_scope_types for comprehension scopes when we check the outer scope. A similar change would probably be required for lambdas as well.

reveal_type(table_with_content) # revealed: list[list[tuple[int, int, str | None]]]
```

The type context is propagated down into the comprehension:

```py
class Person(TypedDict):
name: str

persons: list[Person] = [{"name": n} for n in ["Alice", "Bob"]]
reveal_type(persons) # revealed: list[Person]

# TODO: This should be an error
invalid: list[Person] = [{"misspelled": n} for n in ["Alice", "Bob"]]
```

We promote literals to avoid overly-precise types in invariant positions:

```py
reveal_type([x for x in ("a", "b", "c")]) # revealed: list[str | Unknown]
reveal_type({x for x in (1, 2, 3)}) # revealed: set[int | Unknown]
reveal_type({k: 0 for k in ("a", "b", "c")}) # revealed: dict[str | Unknown, int | Unknown]
```

Type context can prevent this promotion from happening:

```py
list_of_literals: list[Literal["a", "b", "c"]] = [x for x in ("a", "b", "c")]
reveal_type(list_of_literals) # revealed: list[Literal["a", "b", "c"]]

dict_with_literal_keys: dict[Literal["a", "b", "c"], int] = {k: 0 for k in ("a", "b", "c")}
reveal_type(dict_with_literal_keys) # revealed: dict[Literal["a", "b", "c"], int]

dict_with_literal_values: dict[str, Literal[1, 2, 3]] = {str(k): k for k in (1, 2, 3)}
reveal_type(dict_with_literal_values) # revealed: dict[str, Literal[1, 2, 3]]

set_with_literals: set[Literal[1, 2, 3]] = {k for k in (1, 2, 3)}
reveal_type(set_with_literals) # revealed: set[Literal[1, 2, 3]]
```
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ reveal_type({"a": 1, "b": (1, 2), "c": (1, 2, 3)})
## Dict comprehensions

```py
# revealed: dict[@Todo(dict comprehension key type), @Todo(dict comprehension value type)]
# revealed: dict[int | Unknown, int | Unknown]
reveal_type({x: y for x, y in enumerate(range(42))})
```
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@ reveal_type([1, (1, 2), (1, 2, 3)])
## List comprehensions

```py
reveal_type([x for x in range(42)]) # revealed: list[@Todo(list comprehension element type)]
reveal_type([x for x in range(42)]) # revealed: list[int | Unknown]
```
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ reveal_type({1, (1, 2), (1, 2, 3)})
## Set comprehensions

```py
reveal_type({x for x in range(42)}) # revealed: set[@Todo(set comprehension element type)]
reveal_type({x for x in range(42)}) # revealed: set[int | Unknown]
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Documentation of two fuzzer panics involving comprehensions

Type inference for comprehensions was added in <https://github.com/astral-sh/ruff/pull/20962>. It
added two new fuzzer panics that are documented here for regression testing.

## Too many cycle iterations in `place_by_id`

<!-- expect-panic: too many cycle iterations -->

```py
name_5(name_3)
[0 for unique_name_0 in unique_name_1 for unique_name_2 in name_3]

@{name_3 for unique_name_3 in unique_name_4}
class name_4[**name_3](0, name_2=name_5):
pass

try:
name_0 = name_4
except* 0:
pass
else:
match unique_name_12:
case 0:
from name_2 import name_3
case name_0():

@name_4
def name_3():
pass

(name_3 := 0)

@name_3
async def name_5():
pass
```

## Too many cycle iterations in `infer_definition_types`

<!-- expect-panic: too many cycle iterations -->

```py
for name_1 in {
{{0: name_4 for unique_name_0 in unique_name_1}: 0 for unique_name_2 in unique_name_3 if name_4}: 0
for unique_name_4 in name_1
for name_4 in name_1
}:
pass
```
10 changes: 9 additions & 1 deletion crates/ty_python_semantic/src/types/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,14 @@ pub struct FunctionLiteral<'db> {
// The Salsa heap is tracked separately.
impl get_size2::GetSize for FunctionLiteral<'_> {}

fn overloads_and_implementation_cycle_initial<'db>(
_db: &'db dyn Db,
_id: salsa::Id,
_self: FunctionLiteral<'db>,
) -> (Box<[OverloadLiteral<'db>]>, Option<OverloadLiteral<'db>>) {
(Box::new([]), None)
}

#[salsa::tracked]
impl<'db> FunctionLiteral<'db> {
fn name(self, db: &'db dyn Db) -> &'db ast::name::Name {
Expand Down Expand Up @@ -576,7 +584,7 @@ impl<'db> FunctionLiteral<'db> {
self.last_definition(db).spans(db)
}

#[salsa::tracked(returns(ref), heap_size=ruff_memory_usage::heap_size)]
#[salsa::tracked(returns(ref), heap_size=ruff_memory_usage::heap_size, cycle_initial=overloads_and_implementation_cycle_initial)]
fn overloads_and_implementation(
self,
db: &'db dyn Db,
Expand Down
115 changes: 94 additions & 21 deletions crates/ty_python_semantic/src/types/infer/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5943,9 +5943,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ast::Expr::Set(set) => self.infer_set_expression(set, tcx),
ast::Expr::Dict(dict) => self.infer_dict_expression(dict, tcx),
ast::Expr::Generator(generator) => self.infer_generator_expression(generator),
ast::Expr::ListComp(listcomp) => self.infer_list_comprehension_expression(listcomp),
ast::Expr::DictComp(dictcomp) => self.infer_dict_comprehension_expression(dictcomp),
ast::Expr::SetComp(setcomp) => self.infer_set_comprehension_expression(setcomp),
ast::Expr::ListComp(listcomp) => {
self.infer_list_comprehension_expression(listcomp, tcx)
}
ast::Expr::DictComp(dictcomp) => {
self.infer_dict_comprehension_expression(dictcomp, tcx)
}
ast::Expr::SetComp(setcomp) => self.infer_set_comprehension_expression(setcomp, tcx),
ast::Expr::Name(name) => self.infer_name_expression(name),
ast::Expr::Attribute(attribute) => self.infer_attribute_expression(attribute),
ast::Expr::UnaryOp(unary_op) => self.infer_unary_expression(unary_op),
Expand Down Expand Up @@ -6450,52 +6454,121 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
)
}

fn infer_list_comprehension_expression(&mut self, listcomp: &ast::ExprListComp) -> Type<'db> {
/// Return a specialization of the collection class (list, dict, set) based on the type context and the inferred
/// element / key-value types from the comprehension expression.
fn infer_comprehension_specialization(
&self,
collection_class: KnownClass,
inferred_element_types: &[Type<'db>],
tcx: TypeContext<'db>,
) -> Type<'db> {
// Remove any union elements of that are unrelated to the collection type.
let tcx = tcx.map(|annotation| {
annotation.filter_disjoint_elements(
self.db(),
collection_class.to_instance(self.db()),
InferableTypeVars::None,
)
});

if let Some(annotated_element_types) = tcx
.known_specialization(self.db(), collection_class)
.map(|specialization| specialization.types(self.db()))
&& annotated_element_types
.iter()
.zip(inferred_element_types.iter())
.all(|(annotated, inferred)| inferred.is_assignable_to(self.db(), *annotated))
{
collection_class
.to_specialized_instance(self.db(), annotated_element_types.iter().copied())
} else {
collection_class.to_specialized_instance(
self.db(),
inferred_element_types.iter().map(|ty| {
UnionType::from_elements(
self.db(),
[
ty.promote_literals(self.db(), TypeContext::default()),
Type::unknown(),
],
)
}),
)
}
}

fn infer_list_comprehension_expression(
&mut self,
listcomp: &ast::ExprListComp,
tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprListComp {
range: _,
node_index: _,
elt: _,
elt,
generators,
} = listcomp;

self.infer_first_comprehension_iter(generators);

KnownClass::List
.to_specialized_instance(self.db(), [todo_type!("list comprehension element type")])
let scope_id = self
.index
.node_scope(NodeWithScopeRef::ListComprehension(listcomp));
let scope = scope_id.to_scope_id(self.db(), self.file());
let inference = infer_scope_types(self.db(), scope);
let element_type = inference.expression_type(elt.as_ref());

self.infer_comprehension_specialization(KnownClass::List, &[element_type], tcx)
}

fn infer_dict_comprehension_expression(&mut self, dictcomp: &ast::ExprDictComp) -> Type<'db> {
fn infer_dict_comprehension_expression(
&mut self,
dictcomp: &ast::ExprDictComp,
tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprDictComp {
range: _,
node_index: _,
key: _,
value: _,
key,
value,
generators,
} = dictcomp;

self.infer_first_comprehension_iter(generators);

KnownClass::Dict.to_specialized_instance(
self.db(),
[
todo_type!("dict comprehension key type"),
todo_type!("dict comprehension value type"),
],
)
let scope_id = self
.index
.node_scope(NodeWithScopeRef::DictComprehension(dictcomp));
let scope = scope_id.to_scope_id(self.db(), self.file());
let inference = infer_scope_types(self.db(), scope);
let key_type = inference.expression_type(key.as_ref());
let value_type = inference.expression_type(value.as_ref());

self.infer_comprehension_specialization(KnownClass::Dict, &[key_type, value_type], tcx)
}

fn infer_set_comprehension_expression(&mut self, setcomp: &ast::ExprSetComp) -> Type<'db> {
fn infer_set_comprehension_expression(
&mut self,
setcomp: &ast::ExprSetComp,
tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprSetComp {
range: _,
node_index: _,
elt: _,
elt,
generators,
} = setcomp;

self.infer_first_comprehension_iter(generators);

KnownClass::Set
.to_specialized_instance(self.db(), [todo_type!("set comprehension element type")])
let scope_id = self
.index
.node_scope(NodeWithScopeRef::SetComprehension(setcomp));
let scope = scope_id.to_scope_id(self.db(), self.file());
let inference = infer_scope_types(self.db(), scope);
let element_type = inference.expression_type(elt.as_ref());

self.infer_comprehension_specialization(KnownClass::Set, &[element_type], tcx)
}

fn infer_generator_expression_scope(&mut self, generator: &ast::ExprGenerator) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}

ast::Expr::DictComp(dictcomp) => {
self.infer_dict_comprehension_expression(dictcomp);
self.infer_dict_comprehension_expression(dictcomp, TypeContext::default());
self.report_invalid_type_expression(
expression,
format_args!("Dict comprehensions are not allowed in type expressions"),
Expand All @@ -355,7 +355,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}

ast::Expr::ListComp(listcomp) => {
self.infer_list_comprehension_expression(listcomp);
self.infer_list_comprehension_expression(listcomp, TypeContext::default());
self.report_invalid_type_expression(
expression,
format_args!("List comprehensions are not allowed in type expressions"),
Expand All @@ -364,7 +364,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}

ast::Expr::SetComp(setcomp) => {
self.infer_set_comprehension_expression(setcomp);
self.infer_set_comprehension_expression(setcomp, TypeContext::default());
self.report_invalid_type_expression(
expression,
format_args!("Set comprehensions are not allowed in type expressions"),
Expand Down
Loading