@@ -403,7 +403,8 @@ impl<'db> TypeInferenceBuilder<'db> {
403
403
}
404
404
DefinitionKind :: Comprehension ( comprehension) => {
405
405
self . infer_comprehension_definition (
406
- comprehension. node ( ) ,
406
+ comprehension. iterable ( ) ,
407
+ comprehension. target ( ) ,
407
408
comprehension. is_first ( ) ,
408
409
definition,
409
410
) ;
@@ -1545,11 +1546,11 @@ impl<'db> TypeInferenceBuilder<'db> {
1545
1546
1546
1547
/// Infer the type of the `iter` expression of the first comprehension.
1547
1548
fn infer_first_comprehension_iter ( & mut self , comprehensions : & [ ast:: Comprehension ] ) {
1548
- let mut generators_iter = comprehensions. iter ( ) ;
1549
- let Some ( first_generator ) = generators_iter . next ( ) else {
1549
+ let mut comprehensions_iter = comprehensions. iter ( ) ;
1550
+ let Some ( first_comprehension ) = comprehensions_iter . next ( ) else {
1550
1551
unreachable ! ( "Comprehension must contain at least one generator" ) ;
1551
1552
} ;
1552
- self . infer_expression ( & first_generator . iter ) ;
1553
+ self . infer_expression ( & first_comprehension . iter ) ;
1553
1554
}
1554
1555
1555
1556
fn infer_generator_expression ( & mut self , generator : & ast:: ExprGenerator ) -> Type < ' db > {
@@ -1615,9 +1616,7 @@ impl<'db> TypeInferenceBuilder<'db> {
1615
1616
} = generator;
1616
1617
1617
1618
self . infer_expression ( elt) ;
1618
- for comprehension in generators {
1619
- self . infer_comprehension ( comprehension) ;
1620
- }
1619
+ self . infer_comprehensions ( generators) ;
1621
1620
}
1622
1621
1623
1622
fn infer_list_comprehension_expression_scope ( & mut self , listcomp : & ast:: ExprListComp ) {
@@ -1628,9 +1627,7 @@ impl<'db> TypeInferenceBuilder<'db> {
1628
1627
} = listcomp;
1629
1628
1630
1629
self . infer_expression ( elt) ;
1631
- for comprehension in generators {
1632
- self . infer_comprehension ( comprehension) ;
1633
- }
1630
+ self . infer_comprehensions ( generators) ;
1634
1631
}
1635
1632
1636
1633
fn infer_dict_comprehension_expression_scope ( & mut self , dictcomp : & ast:: ExprDictComp ) {
@@ -1643,9 +1640,7 @@ impl<'db> TypeInferenceBuilder<'db> {
1643
1640
1644
1641
self . infer_expression ( key) ;
1645
1642
self . infer_expression ( value) ;
1646
- for comprehension in generators {
1647
- self . infer_comprehension ( comprehension) ;
1648
- }
1643
+ self . infer_comprehensions ( generators) ;
1649
1644
}
1650
1645
1651
1646
fn infer_set_comprehension_expression_scope ( & mut self , setcomp : & ast:: ExprSetComp ) {
@@ -1656,37 +1651,68 @@ impl<'db> TypeInferenceBuilder<'db> {
1656
1651
} = setcomp;
1657
1652
1658
1653
self . infer_expression ( elt) ;
1659
- for comprehension in generators {
1660
- self . infer_comprehension ( comprehension) ;
1661
- }
1654
+ self . infer_comprehensions ( generators) ;
1662
1655
}
1663
1656
1664
- fn infer_comprehension ( & mut self , comprehension : & ast:: Comprehension ) {
1665
- self . infer_definition ( comprehension) ;
1666
- for expr in & comprehension. ifs {
1667
- self . infer_expression ( expr) ;
1657
+ fn infer_comprehensions ( & mut self , comprehensions : & [ ast:: Comprehension ] ) {
1658
+ let mut comprehensions_iter = comprehensions. iter ( ) ;
1659
+ let Some ( first_comprehension) = comprehensions_iter. next ( ) else {
1660
+ unreachable ! ( "Comprehension must contain at least one generator" ) ;
1661
+ } ;
1662
+ self . infer_comprehension ( first_comprehension, true ) ;
1663
+ for comprehension in comprehensions_iter {
1664
+ self . infer_comprehension ( comprehension, false ) ;
1668
1665
}
1669
1666
}
1670
1667
1671
- fn infer_comprehension_definition (
1672
- & mut self ,
1673
- comprehension : & ast:: Comprehension ,
1674
- is_first : bool ,
1675
- definition : Definition < ' db > ,
1676
- ) {
1668
+ fn infer_comprehension ( & mut self , comprehension : & ast:: Comprehension , is_first : bool ) {
1677
1669
let ast:: Comprehension {
1678
1670
range : _,
1679
1671
target,
1680
1672
iter,
1681
- ifs : _ ,
1673
+ ifs,
1682
1674
is_async : _,
1683
1675
} = comprehension;
1684
1676
1685
1677
if !is_first {
1686
1678
self . infer_expression ( iter) ;
1687
1679
}
1688
- // TODO(dhruvmanila): The target type should be inferred based on the iter type instead.
1689
- let target_ty = self . infer_expression ( target) ;
1680
+ // TODO more complex assignment targets
1681
+ if let ast:: Expr :: Name ( name) = target {
1682
+ self . infer_definition ( name) ;
1683
+ } else {
1684
+ self . infer_expression ( target) ;
1685
+ }
1686
+ for expr in ifs {
1687
+ self . infer_expression ( expr) ;
1688
+ }
1689
+ }
1690
+
1691
+ fn infer_comprehension_definition (
1692
+ & mut self ,
1693
+ iterable : & ast:: Expr ,
1694
+ target : & ast:: ExprName ,
1695
+ is_first : bool ,
1696
+ definition : Definition < ' db > ,
1697
+ ) {
1698
+ if !is_first {
1699
+ let expression = self . index . expression ( iterable) ;
1700
+ let result = infer_expression_types ( self . db , expression) ;
1701
+ self . extend ( result) ;
1702
+ let _iterable_ty = self
1703
+ . types
1704
+ . expression_ty ( iterable. scoped_ast_id ( self . db , self . scope ) ) ;
1705
+ }
1706
+ // TODO(dhruvmanila): The iter type for the first comprehension is coming from the
1707
+ // enclosing scope.
1708
+
1709
+ // TODO(dhruvmanila): The target type should be inferred based on the iter type instead,
1710
+ // similar to how it's done in `infer_for_statement_definition`.
1711
+ let target_ty = Type :: Unknown ;
1712
+
1713
+ self . types
1714
+ . expressions
1715
+ . insert ( target. scoped_ast_id ( self . db , self . scope ) , target_ty) ;
1690
1716
self . types . definitions . insert ( definition, target_ty) ;
1691
1717
}
1692
1718
0 commit comments