Skip to content

Commit e965f9c

Browse files
authored
[red-knot] Infer Unknown for the loop var in async for loops (#13243)
1 parent 0512428 commit e965f9c

File tree

3 files changed

+83
-9
lines changed

3 files changed

+83
-9
lines changed

crates/red_knot_python_semantic/src/semantic_index/builder.rs

+1
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,7 @@ where
672672
ForStmtDefinitionNodeRef {
673673
iterable: &node.iter,
674674
target: name_node,
675+
is_async: node.is_async,
675676
},
676677
);
677678
}

crates/red_knot_python_semantic/src/semantic_index/definition.rs

+16-6
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ pub(crate) struct WithItemDefinitionNodeRef<'a> {
152152
pub(crate) struct ForStmtDefinitionNodeRef<'a> {
153153
pub(crate) iterable: &'a ast::Expr,
154154
pub(crate) target: &'a ast::ExprName,
155+
pub(crate) is_async: bool,
155156
}
156157

157158
#[derive(Copy, Clone, Debug)]
@@ -206,12 +207,15 @@ impl DefinitionNodeRef<'_> {
206207
DefinitionNodeRef::AugmentedAssignment(augmented_assignment) => {
207208
DefinitionKind::AugmentedAssignment(AstNodeRef::new(parsed, augmented_assignment))
208209
}
209-
DefinitionNodeRef::For(ForStmtDefinitionNodeRef { iterable, target }) => {
210-
DefinitionKind::For(ForStmtDefinitionKind {
211-
iterable: AstNodeRef::new(parsed.clone(), iterable),
212-
target: AstNodeRef::new(parsed, target),
213-
})
214-
}
210+
DefinitionNodeRef::For(ForStmtDefinitionNodeRef {
211+
iterable,
212+
target,
213+
is_async,
214+
}) => DefinitionKind::For(ForStmtDefinitionKind {
215+
iterable: AstNodeRef::new(parsed.clone(), iterable),
216+
target: AstNodeRef::new(parsed, target),
217+
is_async,
218+
}),
215219
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef {
216220
iterable,
217221
target,
@@ -265,6 +269,7 @@ impl DefinitionNodeRef<'_> {
265269
Self::For(ForStmtDefinitionNodeRef {
266270
iterable: _,
267271
target,
272+
is_async: _,
268273
}) => target.into(),
269274
Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => target.into(),
270275
Self::Parameter(node) => match node {
@@ -388,6 +393,7 @@ impl WithItemDefinitionKind {
388393
pub struct ForStmtDefinitionKind {
389394
iterable: AstNodeRef<ast::Expr>,
390395
target: AstNodeRef<ast::ExprName>,
396+
is_async: bool,
391397
}
392398

393399
impl ForStmtDefinitionKind {
@@ -398,6 +404,10 @@ impl ForStmtDefinitionKind {
398404
pub(crate) fn target(&self) -> &ast::ExprName {
399405
self.target.node()
400406
}
407+
408+
pub(crate) fn is_async(&self) -> bool {
409+
self.is_async
410+
}
401411
}
402412

403413
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]

crates/red_knot_python_semantic/src/types/infer.rs

+66-3
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ impl<'db> TypeInferenceBuilder<'db> {
394394
self.infer_for_statement_definition(
395395
for_statement_definition.target(),
396396
for_statement_definition.iterable(),
397+
for_statement_definition.is_async(),
397398
definition,
398399
);
399400
}
@@ -1045,6 +1046,7 @@ impl<'db> TypeInferenceBuilder<'db> {
10451046
&mut self,
10461047
target: &ast::ExprName,
10471048
iterable: &ast::Expr,
1049+
is_async: bool,
10481050
definition: Definition<'db>,
10491051
) {
10501052
let expression = self.index.expression(iterable);
@@ -1054,9 +1056,14 @@ impl<'db> TypeInferenceBuilder<'db> {
10541056
.types
10551057
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));
10561058

1057-
let loop_var_value_ty = iterable_ty
1058-
.iterate(self.db)
1059-
.unwrap_with_diagnostic(iterable.into(), self);
1059+
let loop_var_value_ty = if is_async {
1060+
// TODO(Alex): async iterables/iterators!
1061+
Type::Unknown
1062+
} else {
1063+
iterable_ty
1064+
.iterate(self.db)
1065+
.unwrap_with_diagnostic(iterable.into(), self)
1066+
};
10601067

10611068
self.types
10621069
.expressions
@@ -3026,6 +3033,62 @@ mod tests {
30263033
Ok(())
30273034
}
30283035

3036+
/// This tests that we understand that `async` for loops
3037+
/// do not work according to the synchronous iteration protocol
3038+
#[test]
3039+
fn invalid_async_for_loop() -> anyhow::Result<()> {
3040+
let mut db = setup_db();
3041+
3042+
db.write_dedented(
3043+
"src/a.py",
3044+
"
3045+
async def foo():
3046+
class Iterator:
3047+
def __next__(self) -> int:
3048+
return 42
3049+
3050+
class Iterable:
3051+
def __iter__(self) -> Iterator:
3052+
return Iterator()
3053+
3054+
async for x in Iterator():
3055+
pass
3056+
",
3057+
)?;
3058+
3059+
// TODO(Alex) async iterables/iterators!
3060+
assert_scope_ty(&db, "src/a.py", &["foo"], "x", "Unknown");
3061+
3062+
Ok(())
3063+
}
3064+
3065+
#[test]
3066+
fn basic_async_for_loop() -> anyhow::Result<()> {
3067+
let mut db = setup_db();
3068+
3069+
db.write_dedented(
3070+
"src/a.py",
3071+
"
3072+
async def foo():
3073+
class IntAsyncIterator:
3074+
async def __anext__(self) -> int:
3075+
return 42
3076+
3077+
class IntAsyncIterable:
3078+
def __aiter__(self) -> IntAsyncIterator:
3079+
return IntAsyncIterator()
3080+
3081+
async for x in IntAsyncIterable():
3082+
pass
3083+
",
3084+
)?;
3085+
3086+
// TODO(Alex) async iterables/iterators!
3087+
assert_scope_ty(&db, "src/a.py", &["foo"], "x", "Unknown");
3088+
3089+
Ok(())
3090+
}
3091+
30293092
#[test]
30303093
fn class_constructor_call_expression() -> anyhow::Result<()> {
30313094
let mut db = setup_db();

0 commit comments

Comments
 (0)