Skip to content

Commit 4d2ee5b

Browse files
Add named expression handling to find_assigned_value (#9109)
1 parent 8314c8b commit 4d2ee5b

File tree

4 files changed

+147
-106
lines changed

4 files changed

+147
-106
lines changed

crates/ruff_linter/resources/test/fixtures/flake8_trio/TRIO115.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ async def func():
2929
trio.sleep(e) # TRIO115
3030

3131
m_x, m_y = 0
32-
trio.sleep(m_y) # TRIO115
33-
trio.sleep(m_x) # TRIO115
32+
trio.sleep(m_y) # OK
33+
trio.sleep(m_x) # OK
3434

3535
m_a = m_b = 0
3636
trio.sleep(m_a) # TRIO115
@@ -43,6 +43,8 @@ async def func():
4343

4444

4545
def func():
46+
import trio
47+
4648
trio.run(trio.sleep(0)) # TRIO115
4749

4850

@@ -55,3 +57,10 @@ def func():
5557

5658
async def func():
5759
await sleep(seconds=0) # TRIO115
60+
61+
62+
def func():
63+
import trio
64+
65+
if (walrus := 0) == 0:
66+
trio.sleep(walrus) # TRIO115

crates/ruff_linter/src/rules/flake8_trio/snapshots/ruff_linter__rules__flake8_trio__tests__TRIO115_TRIO115.py.snap

+69-72
Original file line numberDiff line numberDiff line change
@@ -143,47 +143,7 @@ TRIO115.py:29:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.s
143143
29 |+ trio.lowlevel.checkpoint() # TRIO115
144144
30 30 |
145145
31 31 | m_x, m_y = 0
146-
32 32 | trio.sleep(m_y) # TRIO115
147-
148-
TRIO115.py:32:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)`
149-
|
150-
31 | m_x, m_y = 0
151-
32 | trio.sleep(m_y) # TRIO115
152-
| ^^^^^^^^^^^^^^^ TRIO115
153-
33 | trio.sleep(m_x) # TRIO115
154-
|
155-
= help: Replace with `trio.lowlevel.checkpoint()`
156-
157-
Safe fix
158-
29 29 | trio.sleep(e) # TRIO115
159-
30 30 |
160-
31 31 | m_x, m_y = 0
161-
32 |- trio.sleep(m_y) # TRIO115
162-
32 |+ trio.lowlevel.checkpoint() # TRIO115
163-
33 33 | trio.sleep(m_x) # TRIO115
164-
34 34 |
165-
35 35 | m_a = m_b = 0
166-
167-
TRIO115.py:33:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)`
168-
|
169-
31 | m_x, m_y = 0
170-
32 | trio.sleep(m_y) # TRIO115
171-
33 | trio.sleep(m_x) # TRIO115
172-
| ^^^^^^^^^^^^^^^ TRIO115
173-
34 |
174-
35 | m_a = m_b = 0
175-
|
176-
= help: Replace with `trio.lowlevel.checkpoint()`
177-
178-
Safe fix
179-
30 30 |
180-
31 31 | m_x, m_y = 0
181-
32 32 | trio.sleep(m_y) # TRIO115
182-
33 |- trio.sleep(m_x) # TRIO115
183-
33 |+ trio.lowlevel.checkpoint() # TRIO115
184-
34 34 |
185-
35 35 | m_a = m_b = 0
186-
36 36 | trio.sleep(m_a) # TRIO115
146+
32 32 | trio.sleep(m_y) # OK
187147

188148
TRIO115.py:36:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)`
189149
|
@@ -195,7 +155,7 @@ TRIO115.py:36:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.s
195155
= help: Replace with `trio.lowlevel.checkpoint()`
196156

197157
Safe fix
198-
33 33 | trio.sleep(m_x) # TRIO115
158+
33 33 | trio.sleep(m_x) # OK
199159
34 34 |
200160
35 35 | m_a = m_b = 0
201161
36 |- trio.sleep(m_a) # TRIO115
@@ -264,51 +224,88 @@ TRIO115.py:42:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.s
264224
44 44 |
265225
45 45 | def func():
266226

267-
TRIO115.py:53:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)`
227+
TRIO115.py:48:14: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)`
268228
|
269-
52 | def func():
270-
53 | sleep(0) # TRIO115
271-
| ^^^^^^^^ TRIO115
229+
46 | import trio
230+
47 |
231+
48 | trio.run(trio.sleep(0)) # TRIO115
232+
| ^^^^^^^^^^^^^ TRIO115
272233
|
273234
= help: Replace with `trio.lowlevel.checkpoint()`
274235

275236
Safe fix
276-
46 46 | trio.run(trio.sleep(0)) # TRIO115
237+
45 45 | def func():
238+
46 46 | import trio
277239
47 47 |
278-
48 48 |
279-
49 |-from trio import Event, sleep
280-
49 |+from trio import Event, sleep, lowlevel
240+
48 |- trio.run(trio.sleep(0)) # TRIO115
241+
48 |+ trio.run(trio.lowlevel.checkpoint()) # TRIO115
242+
49 49 |
243+
50 50 |
244+
51 51 | from trio import Event, sleep
245+
246+
TRIO115.py:55:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)`
247+
|
248+
54 | def func():
249+
55 | sleep(0) # TRIO115
250+
| ^^^^^^^^ TRIO115
251+
|
252+
= help: Replace with `trio.lowlevel.checkpoint()`
253+
254+
Safe fix
255+
48 48 | trio.run(trio.sleep(0)) # TRIO115
256+
49 49 |
281257
50 50 |
282-
51 51 |
283-
52 52 | def func():
284-
53 |- sleep(0) # TRIO115
285-
53 |+ lowlevel.checkpoint() # TRIO115
286-
54 54 |
287-
55 55 |
288-
56 56 | async def func():
258+
51 |-from trio import Event, sleep
259+
51 |+from trio import Event, sleep, lowlevel
260+
52 52 |
261+
53 53 |
262+
54 54 | def func():
263+
55 |- sleep(0) # TRIO115
264+
55 |+ lowlevel.checkpoint() # TRIO115
265+
56 56 |
266+
57 57 |
267+
58 58 | async def func():
289268

290-
TRIO115.py:57:11: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)`
269+
TRIO115.py:59:11: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)`
291270
|
292-
56 | async def func():
293-
57 | await sleep(seconds=0) # TRIO115
271+
58 | async def func():
272+
59 | await sleep(seconds=0) # TRIO115
294273
| ^^^^^^^^^^^^^^^^ TRIO115
295274
|
296275
= help: Replace with `trio.lowlevel.checkpoint()`
297276

298277
Safe fix
299-
46 46 | trio.run(trio.sleep(0)) # TRIO115
300-
47 47 |
301-
48 48 |
302-
49 |-from trio import Event, sleep
303-
49 |+from trio import Event, sleep, lowlevel
278+
48 48 | trio.run(trio.sleep(0)) # TRIO115
279+
49 49 |
304280
50 50 |
305-
51 51 |
306-
52 52 | def func():
281+
51 |-from trio import Event, sleep
282+
51 |+from trio import Event, sleep, lowlevel
283+
52 52 |
284+
53 53 |
285+
54 54 | def func():
307286
--------------------------------------------------------------------------------
308-
54 54 |
309-
55 55 |
310-
56 56 | async def func():
311-
57 |- await sleep(seconds=0) # TRIO115
312-
57 |+ await lowlevel.checkpoint() # TRIO115
287+
56 56 |
288+
57 57 |
289+
58 58 | async def func():
290+
59 |- await sleep(seconds=0) # TRIO115
291+
59 |+ await lowlevel.checkpoint() # TRIO115
292+
60 60 |
293+
61 61 |
294+
62 62 | def func():
295+
296+
TRIO115.py:66:9: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)`
297+
|
298+
65 | if (walrus := 0) == 0:
299+
66 | trio.sleep(walrus) # TRIO115
300+
| ^^^^^^^^^^^^^^^^^^ TRIO115
301+
|
302+
= help: Replace with `trio.lowlevel.checkpoint()`
303+
304+
Safe fix
305+
63 63 | import trio
306+
64 64 |
307+
65 65 | if (walrus := 0) == 0:
308+
66 |- trio.sleep(walrus) # TRIO115
309+
66 |+ trio.lowlevel.checkpoint() # TRIO115
313310

314311

crates/ruff_python_semantic/src/analyze/typing.rs

+50-32
Original file line numberDiff line numberDiff line change
@@ -582,42 +582,64 @@ pub fn resolve_assignment<'a>(
582582
pub fn find_assigned_value<'a>(symbol: &str, semantic: &'a SemanticModel<'a>) -> Option<&'a Expr> {
583583
let binding_id = semantic.lookup_symbol(symbol)?;
584584
let binding = semantic.binding(binding_id);
585-
if binding.kind.is_assignment() || binding.kind.is_named_expr_assignment() {
586-
let parent_id = binding.source?;
587-
let parent = semantic.statement(parent_id);
588-
match parent {
589-
Stmt::Assign(ast::StmtAssign { value, targets, .. }) => match value.as_ref() {
590-
Expr::Tuple(ast::ExprTuple { elts, .. })
591-
| Expr::List(ast::ExprList { elts, .. }) => {
585+
match binding.kind {
586+
// Ex) `x := 1`
587+
BindingKind::NamedExprAssignment => {
588+
let parent_id = binding.source?;
589+
let parent = semantic
590+
.expressions(parent_id)
591+
.find_map(|expr| expr.as_named_expr_expr());
592+
if let Some(ast::ExprNamedExpr { target, value, .. }) = parent {
593+
return match_value(symbol, target.as_ref(), value.as_ref());
594+
}
595+
}
596+
// Ex) `x = 1`
597+
BindingKind::Assignment => {
598+
let parent_id = binding.source?;
599+
let parent = semantic.statement(parent_id);
600+
match parent {
601+
Stmt::Assign(ast::StmtAssign { value, targets, .. }) => {
592602
if let Some(target) = targets.iter().find(|target| defines(symbol, target)) {
593-
return match target {
594-
Expr::Tuple(ast::ExprTuple {
595-
elts: target_elts, ..
596-
})
597-
| Expr::List(ast::ExprList {
598-
elts: target_elts, ..
599-
})
600-
| Expr::Set(ast::ExprSet {
601-
elts: target_elts, ..
602-
}) => get_value_by_id(symbol, target_elts, elts),
603-
_ => Some(value.as_ref()),
604-
};
603+
return match_value(symbol, target, value.as_ref());
605604
}
606605
}
607-
_ => return Some(value.as_ref()),
608-
},
609-
Stmt::AnnAssign(ast::StmtAnnAssign {
610-
value: Some(value), ..
611-
}) => {
612-
return Some(value.as_ref());
606+
Stmt::AnnAssign(ast::StmtAnnAssign {
607+
value: Some(value),
608+
target,
609+
..
610+
}) => {
611+
return match_value(symbol, target, value.as_ref());
612+
}
613+
_ => {}
613614
}
614-
Stmt::AugAssign(_) => return None,
615-
_ => return None,
616615
}
616+
_ => {}
617617
}
618618
None
619619
}
620620

621+
/// Given a target and value, find the value that's assigned to the given symbol.
622+
fn match_value<'a>(symbol: &str, target: &Expr, value: &'a Expr) -> Option<&'a Expr> {
623+
match target {
624+
Expr::Name(ast::ExprName { id, .. }) if id.as_str() == symbol => Some(value),
625+
Expr::Tuple(ast::ExprTuple { elts, .. }) | Expr::List(ast::ExprList { elts, .. }) => {
626+
match value {
627+
Expr::Tuple(ast::ExprTuple {
628+
elts: value_elts, ..
629+
})
630+
| Expr::List(ast::ExprList {
631+
elts: value_elts, ..
632+
})
633+
| Expr::Set(ast::ExprSet {
634+
elts: value_elts, ..
635+
}) => get_value_by_id(symbol, elts, value_elts),
636+
_ => None,
637+
}
638+
}
639+
_ => None,
640+
}
641+
}
642+
621643
/// Returns `true` if the [`Expr`] defines the symbol.
622644
fn defines(symbol: &str, expr: &Expr) -> bool {
623645
match expr {
@@ -629,11 +651,7 @@ fn defines(symbol: &str, expr: &Expr) -> bool {
629651
}
630652
}
631653

632-
fn get_value_by_id<'a>(
633-
target_id: &str,
634-
targets: &'a [Expr],
635-
values: &'a [Expr],
636-
) -> Option<&'a Expr> {
654+
fn get_value_by_id<'a>(target_id: &str, targets: &[Expr], values: &'a [Expr]) -> Option<&'a Expr> {
637655
for (target, value) in targets.iter().zip(values.iter()) {
638656
match target {
639657
Expr::Tuple(ast::ExprTuple {

crates/ruff_python_semantic/src/model.rs

+17
Original file line numberDiff line numberDiff line change
@@ -1005,6 +1005,23 @@ impl<'a> SemanticModel<'a> {
10051005
.nth(1)
10061006
}
10071007

1008+
/// Return the [`Expr`] corresponding to the given [`NodeId`].
1009+
#[inline]
1010+
pub fn expression(&self, node_id: NodeId) -> &'a Expr {
1011+
self.nodes
1012+
.ancestor_ids(node_id)
1013+
.find_map(|id| self.nodes[id].as_expression())
1014+
.expect("No expression found")
1015+
}
1016+
1017+
/// Returns an [`Iterator`] over the expressions, starting from the given [`NodeId`].
1018+
/// through to any parents.
1019+
pub fn expressions(&self, node_id: NodeId) -> impl Iterator<Item = &'a Expr> + '_ {
1020+
self.nodes
1021+
.ancestor_ids(node_id)
1022+
.filter_map(move |id| self.nodes[id].as_expression())
1023+
}
1024+
10081025
/// Set the [`Globals`] for the current [`Scope`].
10091026
pub fn set_globals(&mut self, globals: Globals<'a>) {
10101027
// If any global bindings don't already exist in the global scope, add them.

0 commit comments

Comments
 (0)