Skip to content

Commit ac95965

Browse files
authored
Fix incorrect scope for functools partials (#1097)
* Use scope parent * Do not set name onto parent frame for partials * Add test case that captures broken scopes
1 parent e178626 commit ac95965

File tree

4 files changed

+47
-12
lines changed

4 files changed

+47
-12
lines changed

astroid/brain/brain_functools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _functools_partial_inference(node, context=None):
103103
doc=inferred_wrapped_function.doc,
104104
lineno=inferred_wrapped_function.lineno,
105105
col_offset=inferred_wrapped_function.col_offset,
106-
parent=inferred_wrapped_function.parent,
106+
parent=node.parent,
107107
)
108108
partial_function.postinit(
109109
args=inferred_wrapped_function.args,

astroid/node_classes.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,16 +1209,8 @@ def _filter_stmts(self, stmts, frame, offset):
12091209
# want to clear previous assignments if any (hence the test on
12101210
# optional_assign)
12111211
if not (optional_assign or are_exclusive(_stmts[pindex], node)):
1212-
if (
1213-
# In case of partial function node, if the statement is different
1214-
# from the origin function then it can be deleted otherwise it should
1215-
# remain to be able to correctly infer the call to origin function.
1216-
not node.is_function
1217-
or node.qname() != "PartialFunction"
1218-
or node.name != _stmts[pindex].name
1219-
):
1220-
del _stmt_parents[pindex]
1221-
del _stmts[pindex]
1212+
del _stmt_parents[pindex]
1213+
del _stmts[pindex]
12221214
if isinstance(node, AssignName):
12231215
if not optional_assign and stmt.parent is mystmt.parent:
12241216
_stmts = []

astroid/objects.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,10 @@ class PartialFunction(scoped_nodes.FunctionDef):
260260
def __init__(
261261
self, call, name=None, doc=None, lineno=None, col_offset=None, parent=None
262262
):
263-
super().__init__(name, doc, lineno, col_offset, parent)
263+
super().__init__(name, doc, lineno, col_offset, parent=None)
264+
# A typical FunctionDef automatically adds its name to the parent scope,
265+
# but a partial should not, so defer setting parent until after init
266+
self.parent = parent
264267
self.filled_positionals = len(call.positional_arguments[1:])
265268
self.filled_args = call.positional_arguments[1:]
266269
self.filled_keywords = call.keyword_arguments

tests/unittest_brain.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2761,6 +2761,46 @@ def other_test(a, b, *, c=1):
27612761
assert isinstance(inferred, astroid.Const)
27622762
assert inferred.value == expected_value
27632763

2764+
def test_partial_assignment(self):
2765+
"""Make sure partials are not assigned to original scope."""
2766+
ast_nodes = astroid.extract_node(
2767+
"""
2768+
from functools import partial
2769+
def test(a, b): #@
2770+
return a + b
2771+
test2 = partial(test, 1)
2772+
test2 #@
2773+
def test3_scope(a):
2774+
test3 = partial(test, a)
2775+
test3 #@
2776+
"""
2777+
)
2778+
func1, func2, func3 = ast_nodes
2779+
assert func1.parent.scope() == func2.parent.scope()
2780+
assert func1.parent.scope() != func3.parent.scope()
2781+
partial_func3 = next(func3.infer())
2782+
# use scope of parent, so that it doesn't just refer to self
2783+
scope = partial_func3.parent.scope()
2784+
assert scope.name == "test3_scope", "parented by closure"
2785+
2786+
def test_partial_does_not_affect_scope(self):
2787+
"""Make sure partials are not automatically assigned."""
2788+
ast_nodes = astroid.extract_node(
2789+
"""
2790+
from functools import partial
2791+
def test(a, b):
2792+
return a + b
2793+
def scope():
2794+
test2 = partial(test, 1)
2795+
test2 #@
2796+
"""
2797+
)
2798+
test2 = next(ast_nodes.infer())
2799+
mod_scope = test2.root()
2800+
scope = test2.parent.scope()
2801+
assert set(mod_scope) == {"test", "scope", "partial"}
2802+
assert set(scope) == {"test2"}
2803+
27642804

27652805
def test_http_client_brain():
27662806
node = astroid.extract_node(

0 commit comments

Comments
 (0)