Skip to content

Commit 386df4c

Browse files
cdce8ppylint-backport[bot]
authored andcommitted
Assign import nodes to module locals if used with global (#2856)
(cherry picked from commit 1afbca6)
1 parent d1bbd35 commit 386df4c

File tree

4 files changed

+62
-19
lines changed

4 files changed

+62
-19
lines changed

ChangeLog

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ Release date: TBA
1717
* Suppress ``SyntaxWarning`` for invalid escape sequences and return in finally on
1818
Python 3.14 when parsing modules.
1919

20+
* Assign ``Import`` and ``ImportFrom`` nodes to module locals if used with ``global``.
21+
22+
Closes pylint-dev/pylint#10632
23+
2024

2125
What's New in astroid 4.0.0?
2226
============================

astroid/builder.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
import textwrap
1717
import types
1818
import warnings
19-
from collections.abc import Iterator, Sequence
19+
from collections.abc import Collection, Iterator, Sequence
2020
from io import TextIOWrapper
2121
from tokenize import detect_encoding
22-
from typing import TYPE_CHECKING
22+
from typing import TYPE_CHECKING, cast
2323

2424
from astroid import bases, modutils, nodes, raw_building, rebuilder, util
2525
from astroid._ast import ParserModule, get_parser_module
@@ -163,11 +163,11 @@ def _post_build(
163163
module.file_encoding = encoding
164164
self._manager.cache_module(module)
165165
# post tree building steps after we stored the module in the cache:
166-
for from_node in builder._import_from_nodes:
166+
for from_node, global_names in builder._import_from_nodes:
167167
if from_node.modname == "__future__":
168168
for symbol, _ in from_node.names:
169169
module.future_imports.add(symbol)
170-
self.add_from_names_to_locals(from_node)
170+
self.add_from_names_to_locals(from_node, global_names)
171171
# handle delayed assattr nodes
172172
for delayed in builder._delayed_assattr:
173173
self.delayed_assattr(delayed)
@@ -210,31 +210,40 @@ def _data_build(
210210
module = builder.visit_module(node, modname, node_file, package)
211211
return module, builder
212212

213-
def add_from_names_to_locals(self, node: nodes.ImportFrom) -> None:
213+
def add_from_names_to_locals(
214+
self, node: nodes.ImportFrom, global_name: Collection[str]
215+
) -> None:
214216
"""Store imported names to the locals.
215217
216218
Resort the locals if coming from a delayed node
217219
"""
218220

219-
def _key_func(node: nodes.NodeNG) -> int:
220-
return node.fromlineno or 0
221-
222-
def sort_locals(my_list: list[nodes.NodeNG]) -> None:
223-
my_list.sort(key=_key_func)
221+
def add_local(parent_or_root: nodes.NodeNG, name: str) -> None:
222+
parent_or_root.set_local(name, node)
223+
my_list = parent_or_root.scope().locals[name]
224+
if TYPE_CHECKING:
225+
my_list = cast(list[nodes.NodeNG], my_list)
226+
my_list.sort(key=lambda n: n.fromlineno or 0)
224227

225228
assert node.parent # It should always default to the module
229+
module = node.root()
226230
for name, asname in node.names:
227231
if name == "*":
228232
try:
229233
imported = node.do_import_module()
230234
except AstroidBuildingError:
231235
continue
232236
for name in imported.public_names():
233-
node.parent.set_local(name, node)
234-
sort_locals(node.parent.scope().locals[name]) # type: ignore[arg-type]
237+
if name in global_name:
238+
add_local(module, name)
239+
else:
240+
add_local(node.parent, name)
235241
else:
236-
node.parent.set_local(asname or name, node)
237-
sort_locals(node.parent.scope().locals[asname or name]) # type: ignore[arg-type]
242+
name = asname or name
243+
if name in global_name:
244+
add_local(module, name)
245+
else:
246+
add_local(node.parent, name)
238247

239248
def delayed_assattr(self, node: nodes.AssignAttr) -> None:
240249
"""Visit an AssignAttr node.

astroid/rebuilder.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import ast
1212
import sys
1313
import token
14-
from collections.abc import Callable, Generator
14+
from collections.abc import Callable, Collection, Generator
1515
from io import StringIO
1616
from tokenize import TokenInfo, generate_tokens
1717
from typing import TYPE_CHECKING, Final, TypeVar, cast, overload
@@ -61,7 +61,7 @@ def __init__(
6161
self._manager = manager
6262
self._data = data.split("\n") if data else None
6363
self._global_names: list[dict[str, list[nodes.Global]]] = []
64-
self._import_from_nodes: list[nodes.ImportFrom] = []
64+
self._import_from_nodes: list[tuple[nodes.ImportFrom, Collection[str]]] = []
6565
self._delayed_assattr: list[nodes.AssignAttr] = []
6666
self._visit_meths: dict[
6767
type[ast.AST], Callable[[ast.AST, nodes.NodeNG], nodes.NodeNG]
@@ -1099,7 +1099,9 @@ def visit_importfrom(
10991099
parent=parent,
11001100
)
11011101
# store From names to add them to locals after building
1102-
self._import_from_nodes.append(newnode)
1102+
self._import_from_nodes.append(
1103+
(newnode, self._global_names[-1].keys() if self._global_names else ())
1104+
)
11031105
return newnode
11041106

11051107
@overload
@@ -1300,8 +1302,11 @@ def visit_import(self, node: ast.Import, parent: nodes.NodeNG) -> nodes.Import:
13001302
)
13011303
# save import names in parent's locals:
13021304
for name, asname in newnode.names:
1303-
name = asname or name
1304-
parent.set_local(name.split(".")[0], newnode)
1305+
name = (asname or name).split(".")[0]
1306+
if self._global_names and name in self._global_names[-1]:
1307+
parent.root().set_local(name, newnode)
1308+
else:
1309+
parent.set_local(name, newnode)
13051310
return newnode
13061311

13071312
def visit_joinedstr(

tests/test_scoped_nodes.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2803,6 +2803,31 @@ class First(object, object): #@
28032803
astroid["First"].slots()
28042804

28052805

2806+
def test_import_with_global() -> None:
2807+
code = builder.parse(
2808+
"""
2809+
def f1():
2810+
global platform
2811+
from sys import platform as plat
2812+
platform = plat
2813+
2814+
def f2():
2815+
global os, RE, deque, VERSION, Path
2816+
import os
2817+
import re as RE
2818+
from collections import deque
2819+
from sys import version as VERSION
2820+
from pathlib import *
2821+
"""
2822+
)
2823+
assert "platform" in code.locals
2824+
assert "os" in code.locals
2825+
assert "RE" in code.locals
2826+
assert "deque" in code.locals
2827+
assert "VERSION" in code.locals
2828+
assert "Path" in code.locals
2829+
2830+
28062831
class TestFrameNodes:
28072832
@staticmethod
28082833
def test_frame_node():

0 commit comments

Comments
 (0)