Skip to content
This repository was archived by the owner on Jul 11, 2022. It is now read-only.

Commit dd8bde6

Browse files
committed
Improve get_future_imports implementation.
Closes pytest-dev#389.
1 parent 3bdd423 commit dd8bde6

File tree

3 files changed

+33
-12
lines changed

3 files changed

+33
-12
lines changed

black.py

+19-10
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Callable,
2121
Collection,
2222
Dict,
23+
Generator,
2324
Generic,
2425
Iterable,
2526
Iterator,
@@ -2910,7 +2911,23 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
29102911

29112912
def get_future_imports(node: Node) -> Set[str]:
29122913
"""Return a set of __future__ imports in the file."""
2913-
imports = set()
2914+
imports: Set[str] = set()
2915+
2916+
def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
2917+
for child in children:
2918+
if isinstance(child, Leaf):
2919+
if child.type == token.NAME:
2920+
yield child.value
2921+
elif child.type == syms.import_as_name:
2922+
orig_name = child.children[0]
2923+
assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
2924+
assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
2925+
yield orig_name.value
2926+
elif child.type == syms.import_as_names:
2927+
yield from get_imports_from_children(child.children)
2928+
else:
2929+
assert False, "Invalid syntax parsing imports"
2930+
29142931
for child in node.children:
29152932
if child.type != syms.simple_stmt:
29162933
break
@@ -2929,15 +2946,7 @@ def get_future_imports(node: Node) -> Set[str]:
29292946
module_name = first_child.children[1]
29302947
if not isinstance(module_name, Leaf) or module_name.value != "__future__":
29312948
break
2932-
for import_from_child in first_child.children[3:]:
2933-
if isinstance(import_from_child, Leaf):
2934-
if import_from_child.type == token.NAME:
2935-
imports.add(import_from_child.value)
2936-
else:
2937-
assert import_from_child.type == syms.import_as_names
2938-
for leaf in import_from_child.children:
2939-
if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2940-
imports.add(leaf.value)
2949+
imports |= set(get_imports_from_children(first_child.children[3:]))
29412950
else:
29422951
break
29432952
return imports

tests/data/python2_unicode_literals.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/usr/bin/env python2
2-
from __future__ import unicode_literals
2+
from __future__ import unicode_literals as _unicode_literals
3+
from __future__ import absolute_import
4+
from __future__ import print_function as lol, with_function
35

46
u'hello'
57
U"hello"
@@ -9,7 +11,9 @@
911

1012

1113
#!/usr/bin/env python2
12-
from __future__ import unicode_literals
14+
from __future__ import unicode_literals as _unicode_literals
15+
from __future__ import absolute_import
16+
from __future__ import print_function as lol, with_function
1317

1418
"hello"
1519
"hello"

tests/test_black.py

+8
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,14 @@ def test_get_future_imports(self) -> None:
735735
self.assertEqual(set(), black.get_future_imports(node))
736736
node = black.lib2to3_parse("from some.module import black\n")
737737
self.assertEqual(set(), black.get_future_imports(node))
738+
node = black.lib2to3_parse(
739+
"from __future__ import unicode_literals as _unicode_literals"
740+
)
741+
self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
742+
node = black.lib2to3_parse(
743+
"from __future__ import unicode_literals as _lol, print"
744+
)
745+
self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
738746

739747
def test_debug_visitor(self) -> None:
740748
source, _ = read_data("debug_visitor.py")

0 commit comments

Comments
 (0)