diff --git a/ChangeLog b/ChangeLog index b4d9fae880..b6e54d7fbc 100644 --- a/ChangeLog +++ b/ChangeLog @@ -60,6 +60,9 @@ Release date: TBA * On Python versions >= 3.9, ``astroid`` now understands subscripting builtin classes such as ``enumerate`` or ``staticmethod``. +* Instances of NamedTuples both from typing and collections will now be cast to a + NamedTuple instance. This instance proxies the definition of that NamedTuple. + * Fixed inference of ``Enums`` when they are imported under an alias. Closes PyCQA/pylint#5776 diff --git a/astroid/bases.py b/astroid/bases.py index a3f2017f3b..ae644e95b1 100644 --- a/astroid/bases.py +++ b/astroid/bases.py @@ -621,3 +621,7 @@ def __repr__(self): def __str__(self): return f"AsyncGenerator({self._proxied.name})" + + +class NamedTuple(BaseInstance): + """Special node representing a NamedTuple instance""" diff --git a/astroid/brain/brain_namedtuple_enum.py b/astroid/brain/brain_namedtuple_enum.py index 94b0ae7936..e45f739643 100644 --- a/astroid/brain/brain_namedtuple_enum.py +++ b/astroid/brain/brain_namedtuple_enum.py @@ -186,21 +186,30 @@ def infer_named_tuple( node: nodes.Call, context: InferenceContext | None = None ) -> Iterator[nodes.ClassDef]: """Specific inference function for namedtuple Call node""" - tuple_base_name = nodes.Name(name="tuple", parent=node.root()) - class_node, name, attributes = infer_func_form( - node, tuple_base_name, context=context - ) + # Infer which type of NamedTuple we're dealing with (typing or collections) + inferred_namedtuple_call = next(node.func.infer()) + base_names = [ + nodes.Name(name="tuple", parent=node.root()), + nodes.Name( + name=inferred_namedtuple_call.name, + parent=inferred_namedtuple_call.root(), + ), + ] + + class_node, name, attributes = infer_func_form(node, base_names, context=context) call_site = arguments.CallSite.from_call(node, context=context) - node = extract_node("import collections; collections.namedtuple") - try: - func = next(node.infer()) - except StopIteration as e: - raise InferenceError(node=node) from e try: - rename = next(call_site.infer_argument(func, "rename", context)).bool_value() + rename = next( + call_site.infer_argument(inferred_namedtuple_call, "rename", context) + ).bool_value() except (InferenceError, StopIteration): rename = False + # If inferred_namedtuple_call is the ClassDef of typing.NamedTuple + # infer_argument will raise AttributeError + # TODO: See if this exception can be prevented + except AttributeError: + rename = False try: attributes = _check_namedtuple_attributes(name, attributes, rename) @@ -331,7 +340,7 @@ def value(self): __members__ = [''] """ ) - class_node = infer_func_form(node, enum_meta, context=context, enum=True)[0] + class_node = infer_func_form(node, [enum_meta], context=context, enum=True)[0] return iter([class_node.instantiate_class()]) @@ -509,9 +518,11 @@ def infer_typing_namedtuple_function(node, context=None): def infer_typing_namedtuple( node: nodes.Call, context: InferenceContext | None = None ) -> Iterator[nodes.ClassDef]: - """Infer a typing.NamedTuple(...) call.""" - # This is essentially a namedtuple with different arguments - # so we extract the args and infer a named tuple. + """Infer a typing.NamedTuple(...) call. + + We do some premature checking of the node to see if we don't run into any unexpected + values. + """ try: func = next(node.func.infer()) except (InferenceError, StopIteration) as exc: diff --git a/astroid/const.py b/astroid/const.py index 87682c942f..477887d9c8 100644 --- a/astroid/const.py +++ b/astroid/const.py @@ -6,6 +6,11 @@ import sys from pathlib import Path +if sys.version_info >= (3, 8): + from typing import Final +else: + from typing_extensions import Final + PY38 = sys.version_info[:2] == (3, 8) PY38_PLUS = sys.version_info >= (3, 8) PY39_PLUS = sys.version_info >= (3, 9) @@ -33,3 +38,6 @@ class Context(enum.Enum): ASTROID_INSTALL_DIRECTORY = Path(__file__).parent BRAIN_MODULES_DIRECTORY = ASTROID_INSTALL_DIRECTORY / "brain" + +NAMEDTUPLE_BASENAMES: Final[frozenset[str]] = frozenset(("namedtuple", "NamedTuple")) +"""Const used to identify namedtuples in the basenames of subclasses""" diff --git a/astroid/filter_statements.py b/astroid/filter_statements.py index f649cdb8bb..0466d72bd3 100644 --- a/astroid/filter_statements.py +++ b/astroid/filter_statements.py @@ -78,27 +78,27 @@ def _filter_stmts(base_node: nodes.NodeNG, stmts, frame, offset): # # def test(b=1): # ... + # + # If the frame is already a Module we don't need to go up anymore if ( - base_node.parent + not isinstance(myframe, nodes.Module) + and base_node.parent and base_node.statement(future=True) is myframe and myframe.parent ): myframe = myframe.parent.frame() + # We can use line filtering if we are in the same frame. + # mylineno is 0 by default to skip if we can't determine lineno + # or if we are at the module level. lineno information is (for example) + # missing for nodes inserted for living objects. + mylineno = 0 mystmt: nodes.Statement | None = None - if base_node.parent: + if base_node.parent and not isinstance(base_node.parent, nodes.Module): mystmt = base_node.statement(future=True) - - # line filtering if we are in the same frame - # - # take care node may be missing lineno information (this is the case for - # nodes inserted for living objects) - if myframe is frame and mystmt and mystmt.fromlineno is not None: - assert mystmt.fromlineno is not None, mystmt - mylineno = mystmt.fromlineno + offset - else: - # disabling lineno filtering - mylineno = 0 + if myframe is frame and mystmt.fromlineno is not None: + assert mystmt.fromlineno is not None, mystmt + mylineno = mystmt.fromlineno + offset _stmts = [] _stmt_parents = [] diff --git a/astroid/nodes/scoped_nodes/scoped_nodes.py b/astroid/nodes/scoped_nodes/scoped_nodes.py index b027d228ff..2e7d113668 100644 --- a/astroid/nodes/scoped_nodes/scoped_nodes.py +++ b/astroid/nodes/scoped_nodes/scoped_nodes.py @@ -21,7 +21,7 @@ from astroid import bases from astroid import decorators as decorators_mod from astroid import util -from astroid.const import IS_PYPY, PY38, PY38_PLUS, PY39_PLUS +from astroid.const import IS_PYPY, NAMEDTUPLE_BASENAMES, PY38, PY38_PLUS, PY39_PLUS from astroid.context import ( CallContext, InferenceContext, @@ -2521,6 +2521,8 @@ def instantiate_class(self): return objects.ExceptionInstance(self) except MroError: pass + if any(i in NAMEDTUPLE_BASENAMES for i in self.basenames): + return bases.NamedTuple(self) return bases.Instance(self) def getattr(self, name, context=None, class_context=True): diff --git a/tests/unittest_brain.py b/tests/unittest_brain.py index 1eab27639f..dd81a8cd36 100644 --- a/tests/unittest_brain.py +++ b/tests/unittest_brain.py @@ -1502,7 +1502,8 @@ class X(NamedTuple("X", [("a", int), ("b", str), ("c", bytes)])): """ ) self.assertEqual( - [anc.name for anc in klass.ancestors()], ["X", "tuple", "object"] + [anc.name for anc in klass.ancestors()], + ["X", "tuple", "object", "NamedTuple"], ) for anc in klass.ancestors(): self.assertFalse(anc.parent is None) @@ -1611,7 +1612,7 @@ class Example(NamedTuple): """ ) inferred = next(result.infer()) - self.assertIsInstance(inferred, astroid.Instance) + self.assertIsInstance(inferred, bases.NamedTuple) class_attr = inferred.getattr("CLASS_ATTR")[0] self.assertIsInstance(class_attr, astroid.AssignName) @@ -1784,7 +1785,7 @@ def test_typing_namedtuple_dont_crash_on_no_fields(self) -> None: """ ) inferred = next(node.infer()) - self.assertIsInstance(inferred, astroid.Instance) + self.assertIsInstance(inferred, bases.NamedTuple) @test_utils.require_version("3.8") def test_typed_dict(self): @@ -3131,7 +3132,7 @@ def test_http_client_brain() -> None: """ ) inferred = next(node.infer()) - assert isinstance(inferred, astroid.Instance) + assert isinstance(inferred, bases.NamedTuple) def test_http_status_brain() -> None: diff --git a/tests/unittest_inference.py b/tests/unittest_inference.py index 59344b8524..1377705a8c 100644 --- a/tests/unittest_inference.py +++ b/tests/unittest_inference.py @@ -17,7 +17,7 @@ import pytest -from astroid import Slice, arguments +from astroid import Slice, arguments, bases from astroid import decorators as decoratorsmod from astroid import helpers, nodes, objects, test_utils, util from astroid.arguments import CallSite @@ -2193,8 +2193,8 @@ def collections(self): """ ast = parse(code, __name__) - bases = ast["Second"].bases[0] - inferred = next(bases.infer()) + base_classes = ast["Second"].bases[0] + inferred = next(base_classes.infer()) self.assertTrue(inferred) self.assertIsInstance(inferred, nodes.ClassDef) self.assertEqual(inferred.qname(), "collections.Counter") @@ -6249,7 +6249,7 @@ def test_inferaugassign_picking_parent_instead_of_stmt() -> None: # as a string. node = extract_node(code) inferred = next(node.infer()) - assert isinstance(inferred, Instance) + assert isinstance(inferred, bases.NamedTuple) assert inferred.name == "SomeClass" diff --git a/tests/unittest_object_model.py b/tests/unittest_object_model.py index 3cf2e4aee3..9d654530de 100644 --- a/tests/unittest_object_model.py +++ b/tests/unittest_object_model.py @@ -8,7 +8,7 @@ import pytest import astroid -from astroid import builder, nodes, objects, test_utils, util +from astroid import bases, builder, nodes, objects, test_utils, util from astroid.const import PY311_PLUS from astroid.exceptions import InferenceError @@ -203,9 +203,9 @@ class C(A): pass called_mro = next(ast_nodes[5].infer()) self.assertEqual(called_mro.elts, mro.elts) - bases = next(ast_nodes[6].infer()) - self.assertIsInstance(bases, astroid.Tuple) - self.assertEqual([cls.name for cls in bases.elts], ["object"]) + bases_classes = next(ast_nodes[6].infer()) + self.assertIsInstance(bases_classes, astroid.Tuple) + self.assertEqual([cls.name for cls in bases_classes.elts], ["object"]) cls = next(ast_nodes[7].infer()) self.assertIsInstance(cls, astroid.ClassDef) @@ -694,7 +694,7 @@ def foo(): self.assertIsInstance(wrapped, astroid.FunctionDef) self.assertEqual(wrapped.name, "foo") cache_info = next(ast_nodes[2].infer()) - self.assertIsInstance(cache_info, astroid.Instance) + self.assertIsInstance(cache_info, bases.NamedTuple) if __name__ == "__main__":