diff --git a/ChangeLog b/ChangeLog index 98110ede8e..ed9739bb5b 100644 --- a/ChangeLog +++ b/ChangeLog @@ -12,6 +12,11 @@ What's New in astroid 2.6.3? ============================ Release date: TBA + +* Fix a bad inferenece type for yield values inside of a derived class. + + Closes PyCQA/astroid#1090 + * Fix a crash when the node is a 'Module' in the brain builtin inference Closes PyCQA/pylint#4671 diff --git a/astroid/bases.py b/astroid/bases.py index 7f375f8994..e44ee70bd4 100644 --- a/astroid/bases.py +++ b/astroid/bases.py @@ -26,7 +26,7 @@ import collections from astroid import context as contextmod -from astroid import util +from astroid import decorators, util from astroid.const import BUILTINS, PY310_PLUS from astroid.exceptions import ( AstroidTypeError, @@ -543,9 +543,14 @@ class Generator(BaseInstance): special_attributes = util.lazy_descriptor(objectmodel.GeneratorModel) - def __init__(self, parent=None): + def __init__(self, parent=None, generator_initial_context=None): super().__init__() self.parent = parent + self._call_context = contextmod.copy_context(generator_initial_context) + + @decorators.cached + def infer_yield_types(self): + yield from self.parent.infer_yield_result(self._call_context) def callable(self): return False diff --git a/astroid/protocols.py b/astroid/protocols.py index 228000ab18..4e2dc6312e 100644 --- a/astroid/protocols.py +++ b/astroid/protocols.py @@ -489,22 +489,8 @@ def _infer_context_manager(self, mgr, context): # It doesn't interest us. raise InferenceError(node=func) - # Get the first yield point. If it has multiple yields, - # then a RuntimeError will be raised. + yield next(inferred.infer_yield_types()) - possible_yield_points = func.nodes_of_class(nodes.Yield) - # Ignore yields in nested functions - yield_point = next( - (node for node in possible_yield_points if node.scope() == func), None - ) - if yield_point: - if not yield_point.value: - const = nodes.Const(None) - const.parent = yield_point - const.lineno = yield_point.lineno - yield const - else: - yield from yield_point.value.infer(context=context) elif isinstance(inferred, bases.Instance): try: enter = next(inferred.igetattr("__enter__", context=context)) diff --git a/astroid/scoped_nodes.py b/astroid/scoped_nodes.py index 09ed3910de..5fa890d94e 100644 --- a/astroid/scoped_nodes.py +++ b/astroid/scoped_nodes.py @@ -1708,6 +1708,21 @@ def is_generator(self): """ return bool(next(self._get_yield_nodes_skip_lambdas(), False)) + def infer_yield_result(self, context=None): + """Infer what the function yields when called + + :returns: What the function yields + :rtype: iterable(NodeNG or Uninferable) or None + """ + for yield_ in self.nodes_of_class(node_classes.Yield): + if yield_.value is None: + const = node_classes.Const(None) + const.parent = yield_ + const.lineno = yield_.lineno + yield const + elif yield_.scope() == self: + yield from yield_.value.infer(context=context) + def infer_call_result(self, caller=None, context=None): """Infer what the function returns when called. @@ -1719,7 +1734,7 @@ def infer_call_result(self, caller=None, context=None): generator_cls = bases.AsyncGenerator else: generator_cls = bases.Generator - result = generator_cls(self) + result = generator_cls(self, generator_initial_context=context) yield result return # This is really a gigantic hack to work around metaclass generators diff --git a/tests/unittest_inference.py b/tests/unittest_inference.py index 1fe83dc8c7..afc24dc28e 100644 --- a/tests/unittest_inference.py +++ b/tests/unittest_inference.py @@ -6154,5 +6154,26 @@ def test_issue926_binop_referencing_same_name_is_not_uninferable(): assert inferred[0].value == 3 +def test_issue_1090_infer_yield_type_base_class(): + code = """ +import contextlib + +class A: + @contextlib.contextmanager + def get(self): + yield self + +class B(A): + def play(): + pass + +with B().get() as b: + b +b + """ + node = extract_node(code) + assert next(node.infer()).pytype() == ".B" + + if __name__ == "__main__": unittest.main()