Skip to content

Commit 6af7d11

Browse files
Fix AST safety check false negative (#4270)
Fixes #4268. Previously we would allow whitespace changes in all strings; now we allow them only in docstrings. Co-authored-by: Shantanu <[email protected]>
1 parent f03ee11 commit 6af7d11

File tree

4 files changed

+156
-27
lines changed

4 files changed

+156
-27
lines changed

CHANGES.md

+4
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,10 @@
1111
<!-- Changes that affect Black's stable style -->
1212

1313
- Don't move comments along with delimiters, which could cause crashes (#4248)
14+
- Strengthen AST safety check to catch more unsafe changes to strings. Previous versions
15+
of Black would incorrectly format the contents of certain unusual f-strings containing
16+
nested strings with the same quote type. Now, Black will crash on such strings until
17+
support for the new f-string syntax is implemented. (#4270)
1418

1519
### Preview style
1620

src/black/__init__.py

+10-5
Original file line number | Diff line number | Diff line change
@@ -77,8 +77,13 @@
7777
syms,
7878
)
7979
from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
80-
from black.parsing import InvalidInput # noqa F401
81-
from black.parsing import lib2to3_parse, parse_ast, stringify_ast
80+
from black.parsing import ( # noqa F401
81+
ASTSafetyError,
82+
InvalidInput,
83+
lib2to3_parse,
84+
parse_ast,
85+
stringify_ast,
86+
)
8287
from black.ranges import adjusted_lines, convert_unchanged_lines, parse_line_ranges
8388
from black.report import Changed, NothingChanged, Report
8489
from black.trans import iter_fexpr_spans
@@ -1511,7 +1516,7 @@ def assert_equivalent(src: str, dst: str) -> None:
15111516
try:
15121517
src_ast = parse_ast(src)
15131518
except Exception as exc:
1514-
raise AssertionError(
1519+
raise ASTSafetyError(
15151520
"cannot use --safe with this file; failed to parse source file AST: "
15161521
f"{exc}\n"
15171522
"This could be caused by running Black with an older Python version "
@@ -1522,7 +1527,7 @@ def assert_equivalent(src: str, dst: str) -> None:
15221527
dst_ast = parse_ast(dst)
15231528
except Exception as exc:
15241529
log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
1525-
raise AssertionError(
1530+
raise ASTSafetyError(
15261531
f"INTERNAL ERROR: Black produced invalid code: {exc}. "
15271532
"Please report a bug on https://github.com/psf/black/issues. "
15281533
f"This invalid output might be helpful: {log}"
@@ -1532,7 +1537,7 @@ def assert_equivalent(src: str, dst: str) -> None:
15321537
dst_ast_str = "\n".join(stringify_ast(dst_ast))
15331538
if src_ast_str != dst_ast_str:
15341539
log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
1535-
raise AssertionError(
1540+
raise ASTSafetyError(
15361541
"INTERNAL ERROR: Black produced code that is not equivalent to the"
15371542
" source. Please report a bug on "
15381543
f"https://github.com/psf/black/issues. This diff might be helpful: {log}"

src/black/parsing.py

+34-8
Original file line number | Diff line number | Diff line change
@@ -110,6 +110,10 @@ def lib2to3_unparse(node: Node) -> str:
110110
return code
111111

112112

113+
class ASTSafetyError(Exception):
114+
"""Raised when Black's generated code is not equivalent to the old AST."""
115+
116+
113117
def _parse_single_version(
114118
src: str, version: Tuple[int, int], *, type_comments: bool
115119
) -> ast.AST:
@@ -154,9 +158,20 @@ def _normalize(lineend: str, value: str) -> str:
154158
return normalized.strip()
155159

156160

157-
def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
161+
def stringify_ast(node: ast.AST) -> Iterator[str]:
158162
"""Simple visitor generating strings to compare ASTs by content."""
163+
return _stringify_ast(node, [])
164+
159165

166+
def _stringify_ast_with_new_parent(
167+
node: ast.AST, parent_stack: List[ast.AST], new_parent: ast.AST
168+
) -> Iterator[str]:
169+
parent_stack.append(new_parent)
170+
yield from _stringify_ast(node, parent_stack)
171+
parent_stack.pop()
172+
173+
174+
def _stringify_ast(node: ast.AST, parent_stack: List[ast.AST]) -> Iterator[str]:
160175
if (
161176
isinstance(node, ast.Constant)
162177
and isinstance(node.value, str)
@@ -167,7 +182,7 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
167182
# over the kind
168183
node.kind = None
169184

170-
yield f"{' ' * depth}{node.__class__.__name__}("
185+
yield f"{' ' * len(parent_stack)}{node.__class__.__name__}("
171186

172187
for field in sorted(node._fields): # noqa: F402
173188
# TypeIgnore has only one field 'lineno' which breaks this comparison
@@ -179,7 +194,7 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
179194
except AttributeError:
180195
continue
181196

182-
yield f"{' ' * (depth + 1)}{field}="
197+
yield f"{' ' * (len(parent_stack) + 1)}{field}="
183198

184199
if isinstance(value, list):
185200
for item in value:
@@ -191,20 +206,28 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
191206
and isinstance(item, ast.Tuple)
192207
):
193208
for elt in item.elts:
194-
yield from stringify_ast(elt, depth + 2)
209+
yield from _stringify_ast_with_new_parent(
210+
elt, parent_stack, node
211+
)
195212

196213
elif isinstance(item, ast.AST):
197-
yield from stringify_ast(item, depth + 2)
214+
yield from _stringify_ast_with_new_parent(item, parent_stack, node)
198215

199216
elif isinstance(value, ast.AST):
200-
yield from stringify_ast(value, depth + 2)
217+
yield from _stringify_ast_with_new_parent(value, parent_stack, node)
201218

202219
else:
203220
normalized: object
204221
if (
205222
isinstance(node, ast.Constant)
206223
and field == "value"
207224
and isinstance(value, str)
225+
and len(parent_stack) >= 2
226+
and isinstance(parent_stack[-1], ast.Expr)
227+
and isinstance(
228+
parent_stack[-2],
229+
(ast.FunctionDef, ast.AsyncFunctionDef, ast.Module, ast.ClassDef),
230+
)
208231
):
209232
# Constant strings may be indented across newlines, if they are
210233
# docstrings; fold spaces after newlines when comparing. Similarly,
@@ -215,6 +238,9 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
215238
normalized = value.rstrip()
216239
else:
217240
normalized = value
218-
yield f"{' ' * (depth + 2)}{normalized!r}, # {value.__class__.__name__}"
241+
yield (
242+
f"{' ' * (len(parent_stack) + 1)}{normalized!r}, #"
243+
f" {value.__class__.__name__}"
244+
)
219245

220-
yield f"{' ' * depth}) # /{node.__class__.__name__}"
246+
yield f"{' ' * len(parent_stack)}) # /{node.__class__.__name__}"

tests/test_black.py

+108-14
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@
4646
from black.debug import DebugVisitor
4747
from black.mode import Mode, Preview
4848
from black.output import color_diff, diff
49+
from black.parsing import ASTSafetyError
4950
from black.report import Report
5051

5152
# Import other test classes
@@ -1473,10 +1474,6 @@ def test_normalize_line_endings(self) -> None:
14731474
ff(test_file, write_back=black.WriteBack.YES)
14741475
self.assertEqual(test_file.read_bytes(), expected)
14751476

1476-
def test_assert_equivalent_different_asts(self) -> None:
1477-
with self.assertRaises(AssertionError):
1478-
black.assert_equivalent("{}", "None")
1479-
14801477
def test_root_logger_not_used_directly(self) -> None:
14811478
def fail(*args: Any, **kwargs: Any) -> None:
14821479
self.fail("Record created with root logger")
@@ -1962,16 +1959,6 @@ def test_for_handled_unexpected_eof_error(self) -> None:
19621959

19631960
exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
19641961

1965-
def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1966-
with pytest.raises(AssertionError) as err:
1967-
black.assert_equivalent("a«»a = 1", "a«»a = 1")
1968-
1969-
err.match("--safe")
1970-
# Unfortunately the SyntaxError message has changed in newer versions so we
1971-
# can't match it directly.
1972-
err.match("invalid character")
1973-
err.match(r"\(<unknown>, line 1\)")
1974-
19751962
def test_line_ranges_with_code_option(self) -> None:
19761963
code = textwrap.dedent("""\
19771964
if a == b:
@@ -2822,6 +2809,113 @@ def test_format_file_contents(self) -> None:
28222809
black.format_file_contents("x = 1\n", fast=True, mode=black.Mode())
28232810

28242811

2812+
class TestASTSafety(BlackBaseTestCase):
2813+
def check_ast_equivalence(
2814+
self, source: str, dest: str, *, should_fail: bool = False
2815+
) -> None:
2816+
# If we get a failure, make sure it's not because the code itself
2817+
# is invalid, since that will also cause assert_equivalent() to throw
2818+
# ASTSafetyError.
2819+
source = textwrap.dedent(source)
2820+
dest = textwrap.dedent(dest)
2821+
black.parse_ast(source)
2822+
black.parse_ast(dest)
2823+
if should_fail:
2824+
with self.assertRaises(ASTSafetyError):
2825+
black.assert_equivalent(source, dest)
2826+
else:
2827+
black.assert_equivalent(source, dest)
2828+
2829+
def test_assert_equivalent_basic(self) -> None:
2830+
self.check_ast_equivalence("{}", "None", should_fail=True)
2831+
self.check_ast_equivalence("1+2", "1 + 2")
2832+
self.check_ast_equivalence("hi # comment", "hi")
2833+
2834+
def test_assert_equivalent_del(self) -> None:
2835+
self.check_ast_equivalence("del (a, b)", "del a, b")
2836+
2837+
def test_assert_equivalent_strings(self) -> None:
2838+
self.check_ast_equivalence('x = "x"', 'x = " x "', should_fail=True)
2839+
self.check_ast_equivalence(
2840+
'''
2841+
"""docstring """
2842+
''',
2843+
'''
2844+
"""docstring"""
2845+
''',
2846+
)
2847+
self.check_ast_equivalence(
2848+
'''
2849+
"""docstring """
2850+
''',
2851+
'''
2852+
"""ddocstring"""
2853+
''',
2854+
should_fail=True,
2855+
)
2856+
self.check_ast_equivalence(
2857+
'''
2858+
class A:
2859+
"""
2860+
2861+
docstring
2862+
2863+
2864+
"""
2865+
''',
2866+
'''
2867+
class A:
2868+
"""docstring"""
2869+
''',
2870+
)
2871+
self.check_ast_equivalence(
2872+
"""
2873+
def f():
2874+
" docstring "
2875+
""",
2876+
'''
2877+
def f():
2878+
"""docstring"""
2879+
''',
2880+
)
2881+
self.check_ast_equivalence(
2882+
"""
2883+
async def f():
2884+
" docstring "
2885+
""",
2886+
'''
2887+
async def f():
2888+
"""docstring"""
2889+
''',
2890+
)
2891+
2892+
def test_assert_equivalent_fstring(self) -> None:
2893+
major, minor = sys.version_info[:2]
2894+
if major < 3 or (major == 3 and minor < 12):
2895+
pytest.skip("relies on 3.12+ syntax")
2896+
# https://github.com/psf/black/issues/4268
2897+
self.check_ast_equivalence(
2898+
"""print(f"{"|".join([a,b,c])}")""",
2899+
"""print(f"{" | ".join([a,b,c])}")""",
2900+
should_fail=True,
2901+
)
2902+
self.check_ast_equivalence(
2903+
"""print(f"{"|".join(['a','b','c'])}")""",
2904+
"""print(f"{" | ".join(['a','b','c'])}")""",
2905+
should_fail=True,
2906+
)
2907+
2908+
def test_equivalency_ast_parse_failure_includes_error(self) -> None:
2909+
with pytest.raises(ASTSafetyError) as err:
2910+
black.assert_equivalent("a«»a = 1", "a«»a = 1")
2911+
2912+
err.match("--safe")
2913+
# Unfortunately the SyntaxError message has changed in newer versions so we
2914+
# can't match it directly.
2915+
err.match("invalid character")
2916+
err.match(r"\(<unknown>, line 1\)")
2917+
2918+
28252919
try:
28262920
with open(black.__file__, "r", encoding="utf-8") as _bf:
28272921
black_source_lines = _bf.readlines()

0 commit comments

Comments
 (0)