Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,62 @@
# Generator expressions

## Basic

We infer specialized `GeneratorType` instance types for generator expressions:

```py
# revealed: GeneratorType[int, None, None]
reveal_type(x for x in range(10))

# revealed: GeneratorType[tuple[int, str], None, None]
reveal_type((x, str(y)) for x in range(3) for y in range(3))
```

When used in a loop, the yielded type can be inferred:

```py
squares = (x**2 for x in range(10))

for s in squares:
reveal_type(s) # revealed: int
```

`GeneratorType` is covariant in its yielded type, so it can be used where a wider yielded type is
expected:

```py
# revealed: GeneratorType[@Todo(generator expression yield type), @Todo(generator expression send type), @Todo(generator expression return type)]
reveal_type((x for x in range(42)))
from typing import Iterator

def process_numbers(x: Iterator[float]): ...

numbers = (x for x in range(10))
reveal_type(numbers) # revealed: GeneratorType[int, None, None]
process_numbers(numbers)
```

## Async generators

For async generator expressions, we infer specialized `AsyncGeneratorType` instance types:

```py
import asyncio
from typing import AsyncGenerator

async def slow_numbers() -> AsyncGenerator[int, None]:
current = 0
while True:
await asyncio.sleep(1)
yield current
current += 1

async def main() -> None:
slow_squares = (x**2 async for x in slow_numbers())

reveal_type(slow_squares) # revealed: AsyncGeneratorType[int, None]

async for s in slow_squares:
reveal_type(s) # revealed: int
print(s)

asyncio.run(main())
```
40 changes: 29 additions & 11 deletions crates/ty_python_semantic/src/types/infer/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7392,33 +7392,51 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}

/// Infer the type of the `iter` expression of the first comprehension.
fn infer_first_comprehension_iter(&mut self, comprehensions: &[ast::Comprehension]) {
/// Returns the evaluation mode (async or sync) of the comprehension.
fn infer_first_comprehension_iter(
&mut self,
comprehensions: &[ast::Comprehension],
) -> EvaluationMode {
let mut comprehensions_iter = comprehensions.iter();
let Some(first_comprehension) = comprehensions_iter.next() else {
unreachable!("Comprehension must contain at least one generator");
};
self.infer_standalone_expression(&first_comprehension.iter, TypeContext::default());

if first_comprehension.is_async {
EvaluationMode::Async
} else {
EvaluationMode::Sync
}
}

fn infer_generator_expression(&mut self, generator: &ast::ExprGenerator) -> Type<'db> {
let ast::ExprGenerator {
range: _,
node_index: _,
elt: _,
elt,
generators,
parenthesized: _,
} = generator;

self.infer_first_comprehension_iter(generators);
let evaluation_mode = self.infer_first_comprehension_iter(generators);

KnownClass::GeneratorType.to_specialized_instance(
self.db(),
[
todo_type!("generator expression yield type"),
todo_type!("generator expression send type"),
todo_type!("generator expression return type"),
],
)
let scope_id = self
.index
.node_scope(NodeWithScopeRef::GeneratorExpression(generator));
let scope = scope_id.to_scope_id(self.db(), self.file());
let inference = infer_scope_types(self.db(), scope);
let yield_type = inference.expression_type(elt.as_ref());

if evaluation_mode.is_async() {
KnownClass::AsyncGeneratorType
.to_specialized_instance(self.db(), [yield_type, Type::none(self.db())])
} else {
KnownClass::GeneratorType.to_specialized_instance(
self.db(),
[yield_type, Type::none(self.db()), Type::none(self.db())],
)
}
}

/// Return a specialization of the collection class (list, dict, set) based on the type context and the inferred
Expand Down
Loading