Skip to content

Commit e1e9143

Browse files
authored
[red-knot] Handle multiple comprehension targets (#13213)
## Summary Part of #13085, this PR updates the comprehension definition to handle multiple targets. ## Test Plan Update existing semantic index test case for comprehension with multiple targets. Running corpus tests shouldn't panic.
1 parent 3c4ec82 commit e1e9143

File tree

4 files changed

+101
-51
lines changed

4 files changed

+101
-51
lines changed

crates/red_knot_python_semantic/src/semantic_index.rs

+19-4
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
666666
fn comprehension_scope() {
667667
let TestCase { db, file } = test_case(
668668
"
669-
[x for x in iter1]
669+
[x for x, y in iter1]
670670
",
671671
);
672672

@@ -690,7 +690,22 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
690690

691691
let comprehension_symbol_table = index.symbol_table(comprehension_scope_id);
692692

693-
assert_eq!(names(&comprehension_symbol_table), vec!["x"]);
693+
assert_eq!(names(&comprehension_symbol_table), vec!["x", "y"]);
694+
695+
let use_def = index.use_def_map(comprehension_scope_id);
696+
for name in ["x", "y"] {
697+
let definition = use_def
698+
.first_public_definition(
699+
comprehension_symbol_table
700+
.symbol_id_by_name(name)
701+
.expect("symbol exists"),
702+
)
703+
.unwrap();
704+
assert!(matches!(
705+
definition.node(&db),
706+
DefinitionKind::Comprehension(_)
707+
));
708+
}
694709
}
695710

696711
/// Test case to validate that the `x` variable used in the comprehension is referencing the
@@ -730,8 +745,8 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
730745
let DefinitionKind::Comprehension(comprehension) = definition.node(&db) else {
731746
panic!("expected generator definition")
732747
};
733-
let ast::Comprehension { target, .. } = comprehension.node();
734-
let name = target.as_name_expr().unwrap().id().as_str();
748+
let target = comprehension.target();
749+
let name = target.id().as_str();
735750

736751
assert_eq!(name, "x");
737752
assert_eq!(target.range(), TextRange::new(23.into(), 24.into()));

crates/red_knot_python_semantic/src/semantic_index/builder.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ impl<'db> SemanticIndexBuilder<'db> {
285285

286286
// The `iter` of the first generator is evaluated in the outer scope, while all subsequent
287287
// nodes are evaluated in the inner scope.
288+
self.add_standalone_expression(&generator.iter);
288289
self.visit_expr(&generator.iter);
289290
self.push_scope(scope);
290291

@@ -300,6 +301,7 @@ impl<'db> SemanticIndexBuilder<'db> {
300301
}
301302

302303
for generator in generators_iter {
304+
self.add_standalone_expression(&generator.iter);
303305
self.visit_expr(&generator.iter);
304306

305307
self.current_assignment = Some(CurrentAssignment::Comprehension {
@@ -678,7 +680,11 @@ where
678680
Some(CurrentAssignment::Comprehension { node, first }) => {
679681
self.add_definition(
680682
symbol,
681-
ComprehensionDefinitionNodeRef { node, first },
683+
ComprehensionDefinitionNodeRef {
684+
iterable: &node.iter,
685+
target: name_node,
686+
first,
687+
},
682688
);
683689
}
684690
Some(CurrentAssignment::WithItem(with_item)) => {

crates/red_knot_python_semantic/src/semantic_index/definition.rs

+20-17
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ pub(crate) struct ForStmtDefinitionNodeRef<'a> {
156156

157157
#[derive(Copy, Clone, Debug)]
158158
pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
159-
pub(crate) node: &'a ast::Comprehension,
159+
pub(crate) iterable: &'a ast::Expr,
160+
pub(crate) target: &'a ast::ExprName,
160161
pub(crate) first: bool,
161162
}
162163

@@ -211,12 +212,15 @@ impl DefinitionNodeRef<'_> {
211212
target: AstNodeRef::new(parsed, target),
212213
})
213214
}
214-
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef { node, first }) => {
215-
DefinitionKind::Comprehension(ComprehensionDefinitionKind {
216-
node: AstNodeRef::new(parsed, node),
217-
first,
218-
})
219-
}
215+
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef {
216+
iterable,
217+
target,
218+
first,
219+
}) => DefinitionKind::Comprehension(ComprehensionDefinitionKind {
220+
iterable: AstNodeRef::new(parsed.clone(), iterable),
221+
target: AstNodeRef::new(parsed, target),
222+
first,
223+
}),
220224
DefinitionNodeRef::Parameter(parameter) => match parameter {
221225
ast::AnyParameterRef::Variadic(parameter) => {
222226
DefinitionKind::Parameter(AstNodeRef::new(parsed, parameter))
@@ -262,7 +266,7 @@ impl DefinitionNodeRef<'_> {
262266
iterable: _,
263267
target,
264268
}) => target.into(),
265-
Self::Comprehension(ComprehensionDefinitionNodeRef { node, first: _ }) => node.into(),
269+
Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => target.into(),
266270
Self::Parameter(node) => match node {
267271
ast::AnyParameterRef::Variadic(parameter) => parameter.into(),
268272
ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(),
@@ -313,13 +317,18 @@ impl MatchPatternDefinitionKind {
313317

314318
#[derive(Clone, Debug)]
315319
pub struct ComprehensionDefinitionKind {
316-
node: AstNodeRef<ast::Comprehension>,
320+
iterable: AstNodeRef<ast::Expr>,
321+
target: AstNodeRef<ast::ExprName>,
317322
first: bool,
318323
}
319324

320325
impl ComprehensionDefinitionKind {
321-
pub(crate) fn node(&self) -> &ast::Comprehension {
322-
self.node.node()
326+
pub(crate) fn iterable(&self) -> &ast::Expr {
327+
self.iterable.node()
328+
}
329+
330+
pub(crate) fn target(&self) -> &ast::ExprName {
331+
self.target.node()
323332
}
324333

325334
pub(crate) fn is_first(&self) -> bool {
@@ -442,12 +451,6 @@ impl From<&ast::StmtFor> for DefinitionNodeKey {
442451
}
443452
}
444453

445-
impl From<&ast::Comprehension> for DefinitionNodeKey {
446-
fn from(node: &ast::Comprehension) -> Self {
447-
Self(NodeKey::from_node(node))
448-
}
449-
}
450-
451454
impl From<&ast::Parameter> for DefinitionNodeKey {
452455
fn from(node: &ast::Parameter) -> Self {
453456
Self(NodeKey::from_node(node))

crates/red_knot_python_semantic/src/types/infer.rs

+55-29
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,8 @@ impl<'db> TypeInferenceBuilder<'db> {
403403
}
404404
DefinitionKind::Comprehension(comprehension) => {
405405
self.infer_comprehension_definition(
406-
comprehension.node(),
406+
comprehension.iterable(),
407+
comprehension.target(),
407408
comprehension.is_first(),
408409
definition,
409410
);
@@ -1545,11 +1546,11 @@ impl<'db> TypeInferenceBuilder<'db> {
15451546

15461547
/// Infer the type of the `iter` expression of the first comprehension.
15471548
fn infer_first_comprehension_iter(&mut self, comprehensions: &[ast::Comprehension]) {
1548-
let mut generators_iter = comprehensions.iter();
1549-
let Some(first_generator) = generators_iter.next() else {
1549+
let mut comprehensions_iter = comprehensions.iter();
1550+
let Some(first_comprehension) = comprehensions_iter.next() else {
15501551
unreachable!("Comprehension must contain at least one generator");
15511552
};
1552-
self.infer_expression(&first_generator.iter);
1553+
self.infer_expression(&first_comprehension.iter);
15531554
}
15541555

15551556
fn infer_generator_expression(&mut self, generator: &ast::ExprGenerator) -> Type<'db> {
@@ -1615,9 +1616,7 @@ impl<'db> TypeInferenceBuilder<'db> {
16151616
} = generator;
16161617

16171618
self.infer_expression(elt);
1618-
for comprehension in generators {
1619-
self.infer_comprehension(comprehension);
1620-
}
1619+
self.infer_comprehensions(generators);
16211620
}
16221621

16231622
fn infer_list_comprehension_expression_scope(&mut self, listcomp: &ast::ExprListComp) {
@@ -1628,9 +1627,7 @@ impl<'db> TypeInferenceBuilder<'db> {
16281627
} = listcomp;
16291628

16301629
self.infer_expression(elt);
1631-
for comprehension in generators {
1632-
self.infer_comprehension(comprehension);
1633-
}
1630+
self.infer_comprehensions(generators);
16341631
}
16351632

16361633
fn infer_dict_comprehension_expression_scope(&mut self, dictcomp: &ast::ExprDictComp) {
@@ -1643,9 +1640,7 @@ impl<'db> TypeInferenceBuilder<'db> {
16431640

16441641
self.infer_expression(key);
16451642
self.infer_expression(value);
1646-
for comprehension in generators {
1647-
self.infer_comprehension(comprehension);
1648-
}
1643+
self.infer_comprehensions(generators);
16491644
}
16501645

16511646
fn infer_set_comprehension_expression_scope(&mut self, setcomp: &ast::ExprSetComp) {
@@ -1656,37 +1651,68 @@ impl<'db> TypeInferenceBuilder<'db> {
16561651
} = setcomp;
16571652

16581653
self.infer_expression(elt);
1659-
for comprehension in generators {
1660-
self.infer_comprehension(comprehension);
1661-
}
1654+
self.infer_comprehensions(generators);
16621655
}
16631656

1664-
fn infer_comprehension(&mut self, comprehension: &ast::Comprehension) {
1665-
self.infer_definition(comprehension);
1666-
for expr in &comprehension.ifs {
1667-
self.infer_expression(expr);
1657+
fn infer_comprehensions(&mut self, comprehensions: &[ast::Comprehension]) {
1658+
let mut comprehensions_iter = comprehensions.iter();
1659+
let Some(first_comprehension) = comprehensions_iter.next() else {
1660+
unreachable!("Comprehension must contain at least one generator");
1661+
};
1662+
self.infer_comprehension(first_comprehension, true);
1663+
for comprehension in comprehensions_iter {
1664+
self.infer_comprehension(comprehension, false);
16681665
}
16691666
}
16701667

1671-
fn infer_comprehension_definition(
1672-
&mut self,
1673-
comprehension: &ast::Comprehension,
1674-
is_first: bool,
1675-
definition: Definition<'db>,
1676-
) {
1668+
fn infer_comprehension(&mut self, comprehension: &ast::Comprehension, is_first: bool) {
16771669
let ast::Comprehension {
16781670
range: _,
16791671
target,
16801672
iter,
1681-
ifs: _,
1673+
ifs,
16821674
is_async: _,
16831675
} = comprehension;
16841676

16851677
if !is_first {
16861678
self.infer_expression(iter);
16871679
}
1688-
// TODO(dhruvmanila): The target type should be inferred based on the iter type instead.
1689-
let target_ty = self.infer_expression(target);
1680+
// TODO more complex assignment targets
1681+
if let ast::Expr::Name(name) = target {
1682+
self.infer_definition(name);
1683+
} else {
1684+
self.infer_expression(target);
1685+
}
1686+
for expr in ifs {
1687+
self.infer_expression(expr);
1688+
}
1689+
}
1690+
1691+
fn infer_comprehension_definition(
1692+
&mut self,
1693+
iterable: &ast::Expr,
1694+
target: &ast::ExprName,
1695+
is_first: bool,
1696+
definition: Definition<'db>,
1697+
) {
1698+
if !is_first {
1699+
let expression = self.index.expression(iterable);
1700+
let result = infer_expression_types(self.db, expression);
1701+
self.extend(result);
1702+
let _iterable_ty = self
1703+
.types
1704+
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));
1705+
}
1706+
// TODO(dhruvmanila): The iter type for the first comprehension is coming from the
1707+
// enclosing scope.
1708+
1709+
// TODO(dhruvmanila): The target type should be inferred based on the iter type instead,
1710+
// similar to how it's done in `infer_for_statement_definition`.
1711+
let target_ty = Type::Unknown;
1712+
1713+
self.types
1714+
.expressions
1715+
.insert(target.scoped_ast_id(self.db, self.scope), target_ty);
16901716
self.types.definitions.insert(definition, target_ty);
16911717
}
16921718

0 commit comments

Comments
 (0)