diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index c76990f3f4261c..f8e4e34fe25773 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -672,6 +672,7 @@ where ForStmtDefinitionNodeRef { iterable: &node.iter, target: name_node, + is_async: node.is_async, }, ); } diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index 07c36f7361afa0..8667632b920d3c 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -152,6 +152,7 @@ pub(crate) struct WithItemDefinitionNodeRef<'a> { pub(crate) struct ForStmtDefinitionNodeRef<'a> { pub(crate) iterable: &'a ast::Expr, pub(crate) target: &'a ast::ExprName, + pub(crate) is_async: bool, } #[derive(Copy, Clone, Debug)] @@ -206,12 +207,15 @@ impl DefinitionNodeRef<'_> { DefinitionNodeRef::AugmentedAssignment(augmented_assignment) => { DefinitionKind::AugmentedAssignment(AstNodeRef::new(parsed, augmented_assignment)) } - DefinitionNodeRef::For(ForStmtDefinitionNodeRef { iterable, target }) => { - DefinitionKind::For(ForStmtDefinitionKind { - iterable: AstNodeRef::new(parsed.clone(), iterable), - target: AstNodeRef::new(parsed, target), - }) - } + DefinitionNodeRef::For(ForStmtDefinitionNodeRef { + iterable, + target, + is_async, + }) => DefinitionKind::For(ForStmtDefinitionKind { + iterable: AstNodeRef::new(parsed.clone(), iterable), + target: AstNodeRef::new(parsed, target), + is_async, + }), DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef { iterable, target, @@ -265,6 +269,7 @@ impl DefinitionNodeRef<'_> { Self::For(ForStmtDefinitionNodeRef { iterable: _, target, + is_async: _, }) => target.into(), Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => target.into(), Self::Parameter(node) => match node { @@ -388,6 +393,7 @@ impl WithItemDefinitionKind { pub struct ForStmtDefinitionKind { iterable: AstNodeRef, target: AstNodeRef, + is_async: bool, } impl ForStmtDefinitionKind { @@ -398,6 +404,10 @@ impl ForStmtDefinitionKind { pub(crate) fn target(&self) -> &ast::ExprName { self.target.node() } + + pub(crate) fn is_async(&self) -> bool { + self.is_async + } } #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 472a171579d244..52201c6295fe3b 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -394,6 +394,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_for_statement_definition( for_statement_definition.target(), for_statement_definition.iterable(), + for_statement_definition.is_async(), definition, ); } @@ -1045,6 +1046,7 @@ impl<'db> TypeInferenceBuilder<'db> { &mut self, target: &ast::ExprName, iterable: &ast::Expr, + is_async: bool, definition: Definition<'db>, ) { let expression = self.index.expression(iterable); @@ -1054,9 +1056,14 @@ impl<'db> TypeInferenceBuilder<'db> { .types .expression_ty(iterable.scoped_ast_id(self.db, self.scope)); - let loop_var_value_ty = iterable_ty - .iterate(self.db) - .unwrap_with_diagnostic(iterable.into(), self); + let loop_var_value_ty = if is_async { + // TODO(Alex): async iterables/iterators! + Type::Unknown + } else { + iterable_ty + .iterate(self.db) + .unwrap_with_diagnostic(iterable.into(), self) + }; self.types .expressions @@ -3026,6 +3033,62 @@ mod tests { Ok(()) } + /// This tests that we understand that `async` for loops + /// do not work according to the synchronous iteration protocol + #[test] + fn invalid_async_for_loop() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + async def foo(): + class Iterator: + def __next__(self) -> int: + return 42 + + class Iterable: + def __iter__(self) -> Iterator: + return Iterator() + + async for x in Iterator(): + pass + ", + )?; + + // TODO(Alex) async iterables/iterators! + assert_scope_ty(&db, "src/a.py", &["foo"], "x", "Unknown"); + + Ok(()) + } + + #[test] + fn basic_async_for_loop() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + async def foo(): + class IntAsyncIterator: + async def __anext__(self) -> int: + return 42 + + class IntAsyncIterable: + def __aiter__(self) -> IntAsyncIterator: + return IntAsyncIterator() + + async for x in IntAsyncIterable(): + pass + ", + )?; + + // TODO(Alex) async iterables/iterators! + assert_scope_ty(&db, "src/a.py", &["foo"], "x", "Unknown"); + + Ok(()) + } + #[test] fn class_constructor_call_expression() -> anyhow::Result<()> { let mut db = setup_db();