diff --git a/ChangeLog b/ChangeLog index 426a9b62b..bbfe7c7fd 100644 --- a/ChangeLog +++ b/ChangeLog @@ -24,6 +24,10 @@ Release date: TBA * Suppress ``SyntaxWarning`` for invalid escape sequences and return in finally on Python 3.14 when parsing modules. +* Assign ``Import`` and ``ImportFrom`` nodes to module locals if used with ``global``. + + Closes pylint-dev/pylint#10632 + What's New in astroid 4.0.0? ============================ diff --git a/astroid/builder.py b/astroid/builder.py index b2b723cda..f166ab492 100644 --- a/astroid/builder.py +++ b/astroid/builder.py @@ -16,10 +16,10 @@ import textwrap import types import warnings -from collections.abc import Iterator, Sequence +from collections.abc import Collection, Iterator, Sequence from io import TextIOWrapper from tokenize import detect_encoding -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from astroid import bases, modutils, nodes, raw_building, rebuilder, util from astroid._ast import ParserModule, get_parser_module @@ -163,11 +163,11 @@ def _post_build( module.file_encoding = encoding self._manager.cache_module(module) # post tree building steps after we stored the module in the cache: - for from_node in builder._import_from_nodes: + for from_node, global_names in builder._import_from_nodes: if from_node.modname == "__future__": for symbol, _ in from_node.names: module.future_imports.add(symbol) - self.add_from_names_to_locals(from_node) + self.add_from_names_to_locals(from_node, global_names) # handle delayed assattr nodes for delayed in builder._delayed_assattr: self.delayed_assattr(delayed) @@ -210,19 +210,23 @@ def _data_build( module = builder.visit_module(node, modname, node_file, package) return module, builder - def add_from_names_to_locals(self, node: nodes.ImportFrom) -> None: + def add_from_names_to_locals( + self, node: nodes.ImportFrom, global_name: Collection[str] + ) -> None: """Store imported names to the locals. Resort the locals if coming from a delayed node """ - def _key_func(node: nodes.NodeNG) -> int: - return node.fromlineno or 0 - - def sort_locals(my_list: list[nodes.NodeNG]) -> None: - my_list.sort(key=_key_func) + def add_local(parent_or_root: nodes.NodeNG, name: str) -> None: + parent_or_root.set_local(name, node) + my_list = parent_or_root.scope().locals[name] + if TYPE_CHECKING: + my_list = cast(list[nodes.NodeNG], my_list) + my_list.sort(key=lambda n: n.fromlineno or 0) assert node.parent # It should always default to the module + module = node.root() for name, asname in node.names: if name == "*": try: @@ -230,11 +234,16 @@ def sort_locals(my_list: list[nodes.NodeNG]) -> None: except AstroidBuildingError: continue for name in imported.public_names(): - node.parent.set_local(name, node) - sort_locals(node.parent.scope().locals[name]) # type: ignore[arg-type] + if name in global_name: + add_local(module, name) + else: + add_local(node.parent, name) else: - node.parent.set_local(asname or name, node) - sort_locals(node.parent.scope().locals[asname or name]) # type: ignore[arg-type] + name = asname or name + if name in global_name: + add_local(module, name) + else: + add_local(node.parent, name) def delayed_assattr(self, node: nodes.AssignAttr) -> None: """Visit an AssignAttr node. diff --git a/astroid/rebuilder.py b/astroid/rebuilder.py index 4b1a1e415..97f3a390e 100644 --- a/astroid/rebuilder.py +++ b/astroid/rebuilder.py @@ -11,7 +11,7 @@ import ast import sys import token -from collections.abc import Callable, Generator +from collections.abc import Callable, Collection, Generator from io import StringIO from tokenize import TokenInfo, generate_tokens from typing import TYPE_CHECKING, Final, TypeVar, cast, overload @@ -61,7 +61,7 @@ def __init__( self._manager = manager self._data = data.split("\n") if data else None self._global_names: list[dict[str, list[nodes.Global]]] = [] - self._import_from_nodes: list[nodes.ImportFrom] = [] + self._import_from_nodes: list[tuple[nodes.ImportFrom, Collection[str]]] = [] self._delayed_assattr: list[nodes.AssignAttr] = [] self._visit_meths: dict[ type[ast.AST], Callable[[ast.AST, nodes.NodeNG], nodes.NodeNG] @@ -1099,7 +1099,9 @@ def visit_importfrom( parent=parent, ) # store From names to add them to locals after building - self._import_from_nodes.append(newnode) + self._import_from_nodes.append( + (newnode, self._global_names[-1].keys() if self._global_names else ()) + ) return newnode @overload @@ -1300,8 +1302,11 @@ def visit_import(self, node: ast.Import, parent: nodes.NodeNG) -> nodes.Import: ) # save import names in parent's locals: for name, asname in newnode.names: - name = asname or name - parent.set_local(name.split(".")[0], newnode) + name = (asname or name).split(".")[0] + if self._global_names and name in self._global_names[-1]: + parent.root().set_local(name, newnode) + else: + parent.set_local(name, newnode) return newnode def visit_joinedstr( diff --git a/tests/test_scoped_nodes.py b/tests/test_scoped_nodes.py index f3244c6d5..c95f53fd3 100644 --- a/tests/test_scoped_nodes.py +++ b/tests/test_scoped_nodes.py @@ -2803,6 +2803,31 @@ class First(object, object): #@ astroid["First"].slots() +def test_import_with_global() -> None: + code = builder.parse( + """ + def f1(): + global platform + from sys import platform as plat + platform = plat + + def f2(): + global os, RE, deque, VERSION, Path + import os + import re as RE + from collections import deque + from sys import version as VERSION + from pathlib import * + """ + ) + assert "platform" in code.locals + assert "os" in code.locals + assert "RE" in code.locals + assert "deque" in code.locals + assert "VERSION" in code.locals + assert "Path" in code.locals + + class TestFrameNodes: @staticmethod def test_frame_node():