Skip to content

Commit 87b55a6

Browse files
hippo91PCManticore
authored andcommitted
Avoid statement deletion in the _filter_stmts method of the LookupMixin class for PartialFunction
In the case where the node is a PartialFunction and its name is the same as the current statement's name, avoid the statement deletion. The problem was that a call to a function that has been previously called vit a functools.partial was wrongly inferred. The bug comes from the _filter_stmts method of the LookupMixin class. The deletion of the current statement should not be made in the case where the node is an instance of the PartialFunction class and if the node's name is the same as the statement's name. This change also extracts PartialFunction from brain_functools into astroid.objects so that we remove a circular import problem. Close pylint-dev/pylint#2588
1 parent c54ca8e commit 87b55a6

File tree

5 files changed

+50
-28
lines changed

5 files changed

+50
-28
lines changed

ChangeLog

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ What's New in astroid 2.2.0?
66
============================
77
Release Date: TBA
88

9+
* Fix a bug where a call to a function that has been previously called via
10+
functools.partial was wrongly inferred
11+
12+
Close PyCQA/pylint#2588
913

1014
* Fix a recursion error caused by inferring the ``slice`` builtin.
1115

astroid/brain/brain_functools.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from astroid import helpers
1313
from astroid.interpreter import objectmodel
1414
from astroid import MANAGER
15+
from astroid import objects
1516

1617

1718
LRU_CACHE = "functools.lru_cache"
@@ -98,31 +99,8 @@ def _functools_partial_inference(node, context=None):
9899
"wrapped function received unknown parameters"
99100
)
100101

101-
# Return a wrapped() object that can be used further for inference
102-
class PartialFunction(astroid.FunctionDef):
103-
104-
filled_positionals = len(call.positional_arguments[1:])
105-
filled_keywords = list(call.keyword_arguments)
106-
107-
def infer_call_result(self, caller=None, context=None):
108-
nonlocal call
109-
filled_args = call.positional_arguments[1:]
110-
filled_keywords = call.keyword_arguments
111-
112-
if context:
113-
current_passed_keywords = {
114-
keyword for (keyword, _) in context.callcontext.keywords
115-
}
116-
for keyword, value in filled_keywords.items():
117-
if keyword not in current_passed_keywords:
118-
context.callcontext.keywords.append((keyword, value))
119-
120-
call_context_args = context.callcontext.args or []
121-
context.callcontext.args = filled_args + call_context_args
122-
123-
return super().infer_call_result(caller=caller, context=context)
124-
125-
partial_function = PartialFunction(
102+
partial_function = objects.PartialFunction(
103+
call,
126104
name=inferred_wrapped_function.name,
127105
doc=inferred_wrapped_function.doc,
128106
lineno=inferred_wrapped_function.lineno,

astroid/node_classes.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,8 +1217,16 @@ def _filter_stmts(self, stmts, frame, offset):
12171217
# want to clear previous assignments if any (hence the test on
12181218
# optional_assign)
12191219
if not (optional_assign or are_exclusive(_stmts[pindex], node)):
1220-
del _stmt_parents[pindex]
1221-
del _stmts[pindex]
1220+
if (
1221+
# In case of partial function node, if the statement is different
1222+
# from the origin function then it can be deleted otherwise it should
1223+
# remain to be able to correctly infer the call to origin function.
1224+
not node.is_function
1225+
or node.qname() != "PartialFunction"
1226+
or node.name != _stmts[pindex].name
1227+
):
1228+
del _stmt_parents[pindex]
1229+
del _stmts[pindex]
12221230
if isinstance(node, AssignName):
12231231
if not optional_assign and stmt.parent is mystmt.parent:
12241232
_stmts = []

astroid/objects.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,35 @@ class DictValues(bases.Proxy):
248248
__repr__ = node_classes.NodeNG.__repr__
249249

250250

251+
class PartialFunction(scoped_nodes.FunctionDef):
252+
"""A class representing partial function obtained via functools.partial"""
253+
254+
def __init__(
255+
self, call, name=None, doc=None, lineno=None, col_offset=None, parent=None
256+
):
257+
super().__init__(name, doc, lineno, col_offset, parent)
258+
self.filled_positionals = len(call.positional_arguments[1:])
259+
self.filled_args = call.positional_arguments[1:]
260+
self.filled_keywords = call.keyword_arguments
261+
262+
def infer_call_result(self, caller=None, context=None):
263+
if context:
264+
current_passed_keywords = {
265+
keyword for (keyword, _) in context.callcontext.keywords
266+
}
267+
for keyword, value in self.filled_keywords.items():
268+
if keyword not in current_passed_keywords:
269+
context.callcontext.keywords.append((keyword, value))
270+
271+
call_context_args = context.callcontext.args or []
272+
context.callcontext.args = self.filled_args + call_context_args
273+
274+
return super().infer_call_result(caller=caller, context=context)
275+
276+
def qname(self):
277+
return self.__class__.__name__
278+
279+
251280
# TODO: Hack to solve the circular import problem between node_classes and objects
252281
# This is not needed in 2.0, which has a cleaner design overall
253282
node_classes.Dict.__bases__ = (node_classes.NodeNG, DictInstance)

astroid/tests/unittest_brain.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1776,9 +1776,12 @@ def other_test(a, b, *, c=1):
17761776
partial(other_test, c=4)(1, 3) #@
17771777
partial(other_test, 4, c=4)(4) #@
17781778
partial(other_test, 4, c=4)(b=5) #@
1779+
test(1, 2) #@
1780+
partial(other_test, 1, 2)(c=3) #@
1781+
partial(test, b=4)(a=3) #@
17791782
"""
17801783
)
1781-
expected_values = [4, 7, 7, 3, 12, 16, 32, 36]
1784+
expected_values = [4, 7, 7, 3, 12, 16, 32, 36, 3, 9, 7]
17821785
for node, expected_value in zip(ast_nodes, expected_values):
17831786
inferred = next(node.infer())
17841787
assert isinstance(inferred, astroid.Const)

0 commit comments

Comments
 (0)