diff --git a/rewrite-python/rewrite/src/rewrite/python/_parser_visitor.py b/rewrite-python/rewrite/src/rewrite/python/_parser_visitor.py index 28ca9b69e1..056bc4a20d 100644 --- a/rewrite-python/rewrite/src/rewrite/python/_parser_visitor.py +++ b/rewrite-python/rewrite/src/rewrite/python/_parser_visitor.py @@ -547,6 +547,8 @@ def visit_AnnAssign(self, node): def visit_For(self, node): prefix = self.__source_before('for') target = self.__convert(node.target) + control_prefix = target.prefix + target = target.replace(prefix=Space.EMPTY) # Wrap target in ExpressionStatement so it can be used as a Statement wrapped_target = py.ExpressionStatement(random_id(), target) in_prefix = self.__source_before('in') @@ -555,7 +557,7 @@ def visit_For(self, node): control = j.ForEachLoop.Control( random_id(), - Space.EMPTY, # No parentheses in Python, so no prefix space for control + control_prefix, Markers.EMPTY, self.__pad_right(wrapped_target, in_prefix), # Right padding has space before 'in' self.__pad_right(iterable, Space.EMPTY) # ':' comes from body's Block prefix @@ -589,15 +591,19 @@ def visit_AsyncFor(self, node): ) def visit_While(self, node): + while_prefix = self.__source_before('while') + condition = self.__convert(node.test) + ctrl_prefix = condition.prefix + condition = condition.replace(prefix=Space.EMPTY) while_ = j.WhileLoop( random_id(), - self.__source_before('while'), + while_prefix, Markers.EMPTY, j.ControlParentheses( random_id(), - Space.EMPTY, + ctrl_prefix, Markers.EMPTY, - self.__pad_right(self.__convert(node.test), Space.EMPTY) + self.__pad_right(condition, Space.EMPTY) ), self.__pad_right(self.__convert_block(node.body), Space.EMPTY) ) @@ -619,8 +625,11 @@ def visit_If(self, node): prefix = self.__source_before('elif') else: prefix = self.__source_before('if') - condition = j.ControlParentheses(random_id(), Space.EMPTY, Markers.EMPTY, - self.__pad_right(self.__convert(node.test), Space.EMPTY)) + cond_expr = self.__convert(node.test) + ctrl_prefix = cond_expr.prefix + cond_expr = cond_expr.replace(prefix=Space.EMPTY) + condition = j.ControlParentheses(random_id(), ctrl_prefix, Markers.EMPTY, + self.__pad_right(cond_expr, Space.EMPTY)) then = self.__pad_right(self.__convert_block(node.body), Space.EMPTY) elze = None if len(node.orelse) > 0: @@ -1118,15 +1127,19 @@ def visit_ExceptHandler(self, node, is_exception_group: bool = False): ) def visit_Match(self, node): + match_prefix = self.__source_before('match') + subject = self.__convert(node.subject) + ctrl_prefix = subject.prefix + subject = subject.replace(prefix=Space.EMPTY) return j.Switch( random_id(), - self.__source_before('match'), + match_prefix, Markers.EMPTY, j.ControlParentheses( random_id(), - Space.EMPTY, + ctrl_prefix, Markers.EMPTY, - self.__pad_right(self.__convert(node.subject), Space.EMPTY) + self.__pad_right(subject, Space.EMPTY) ), self.__convert_block(node.cases) ) @@ -1856,6 +1869,23 @@ def __convert_binary_operator(self, op) -> Union[JLeftPadded[j.Binary.Type], JLe raise ValueError(f"Unsupported operator: {op}") return self.__pad_left(self.__source_before(op_str), op) + @staticmethod + def _hoist_concat_prefix(node): + if not isinstance(node, py.Binary) or node.prefix != Space.EMPTY: + return node + leftmost = node + chain = [] + while isinstance(leftmost.left, py.Binary): + chain.append(leftmost) + leftmost = leftmost.left + if leftmost.left.prefix == Space.EMPTY: + return node + prefix = leftmost.left.prefix + leftmost = leftmost.replace(left=leftmost.left.replace(prefix=Space.EMPTY)) + for parent in reversed(chain): + leftmost = parent.replace(left=leftmost) + return leftmost.replace(prefix=prefix) + @staticmethod def _is_byte_string(tok_string: str) -> bool: """Check if a string token represents a byte string (has b/B in prefix).""" @@ -1927,7 +1957,7 @@ def visit_Constant(self, node): if idx >= len(self._tokens) - 1 and tok.type == token.ENDMARKER: break - return res + return self._hoist_concat_prefix(res) def __map_literal_simple(self, node): """Map a non-string constant (numbers, None, True, False, Ellipsis).""" @@ -2369,7 +2399,7 @@ def visit_JoinedStr(self, node): is_first = False - return res + return self._hoist_concat_prefix(res) def visit_TemplateStr(self, node): leading_prefix = self.__whitespace() @@ -2436,7 +2466,7 @@ def visit_TemplateStr(self, node): is_first = False - return res + return self._hoist_concat_prefix(res) def visit_FormattedValue(self, node): raise ValueError("This method should not be called directly") @@ -3360,7 +3390,7 @@ def __map_fstring_as_literal(self, node: ast.JoinedStr, leading_prefix: Space, t ) is_first = False assert res is not None - return res + return self._hoist_concat_prefix(res) def __map_fstring(self, node, prefix: Space, tok: TokenInfo, value_idx: int = 0, *, _start=None, _middle=None, _end=None) -> \ @@ -3482,7 +3512,7 @@ def __map_fstring(self, node, prefix: Space, tok: TokenInfo, value_idx: int = 0, self._type_mapping.type(node) ) - expr = self.__pad_right(nested, Space.EMPTY) + expr = self.__pad_right(self._hoist_concat_prefix(nested), Space.EMPTY) else: expr = self.__pad_right( self.__convert(value_inner), diff --git a/rewrite-python/rewrite/src/rewrite/python/format/spaces_visitor.py b/rewrite-python/rewrite/src/rewrite/python/format/spaces_visitor.py index 44b04a6a28..685e6c4be2 100644 --- a/rewrite-python/rewrite/src/rewrite/python/format/spaces_visitor.py +++ b/rewrite-python/rewrite/src/rewrite/python/format/spaces_visitor.py @@ -107,8 +107,13 @@ def visit_catch(self, catch: Try.Catch, p: P) -> J: def visit_control_parentheses(self, control_parens: ControlParentheses[J2], p: P) -> J: cp = cast(ControlParentheses[J2], super().visit_control_parentheses(control_parens, p)) - cp = space_before(cp, False) - cp = cp.padding.replace(tree=cp.padding.tree.replace(element=space_before(cp.tree, True))) + parent = self.cursor.parent_tree_cursor() + if parent is not None and isinstance(parent.value, Try.Catch): + cp = space_before(cp, False) + cp = cp.padding.replace(tree=cp.padding.tree.replace(element=space_before(cp.tree, True))) + else: + cp = space_before(cp, True) + cp = cp.padding.replace(tree=cp.padding.tree.replace(element=space_before(cp.tree, False))) return cp def visit_named_argument(self, named: NamedArgument, p: P) -> J: @@ -319,10 +324,10 @@ def visit_for_each_loop(self, for_each: j.ForEachLoop, p: P) -> J: control = fl.control # Set single space before loop target e.g. for i in...: <-> for i in ...: - var_rp = control.padding.variable - var_rp = var_rp.replace(element=space_before(var_rp.element, True)) + control = space_before(control, True) # Set single space before 'in' keyword + var_rp = control.padding.variable var_rp = space_after_right_padded(var_rp, True) # Set single space before loop iterable e.g. for i in []: <-> for i in []: diff --git a/rewrite-python/rewrite/tests/python/whitespace_attachment_test.py b/rewrite-python/rewrite/tests/python/whitespace_attachment_test.py new file mode 100644 index 0000000000..f49c6ed11b --- /dev/null +++ b/rewrite-python/rewrite/tests/python/whitespace_attachment_test.py @@ -0,0 +1,173 @@ +# Copyright 2026 the original author or authors. +#
+# Licensed under the Moderne Source Available License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +#
+# https://docs.moderne.io/licensing/moderne-source-available-license +#
+# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +from pathlib import Path +from typing import List, Optional, Union + +import pytest + + +from rewrite.java import J +from rewrite.python.support_types import Py +from rewrite.python.printer import PythonPrinter, PythonJavaPrinter, PrintOutputCapture +from rewrite.python._parser_visitor import ParserVisitor +from rewrite import Tree + + +def _prettify_type(tree: Tree) -> str: + cls = type(tree) + module = cls.__module__ + if 'python' in module: + return f"Py.{cls.__name__}" + elif 'java' in module: + return f"J.{cls.__name__}" + return cls.__name__ + + +class OutputNode: + def __init__(self, element: Tree): + self.element = element + self.children: List[Union['OutputNode', str]] = [] + + def add_child(self, child: Union['OutputNode', str]) -> None: + self.children.append(child) + + def __str__(self) -> str: + children_str = ', '.join( + f"Text({child})" if isinstance(child, str) else str(child) + for child in self.children + ) + return f"{_prettify_type(self.element)}{{{children_str}}}" + + +class TreeStructurePrintOutputCapture(PrintOutputCapture): + def __init__(self): + super().__init__() + self.root_nodes: List[OutputNode] = [] + self._node_stack: List[OutputNode] = [] + + def start_node(self, element: Tree) -> None: + node = OutputNode(element) + if self._node_stack: + self._node_stack[-1].add_child(node) + else: + self.root_nodes.append(node) + self._node_stack.append(node) + + def end_node(self) -> None: + self._node_stack.pop() + + def append(self, text: Optional[str]) -> 'TreeStructurePrintOutputCapture': + if text and len(text) > 0: + if self._node_stack: + self._node_stack[-1].add_child(text) + super().append(text) + return self + + +class TreeCapturingPythonJavaPrinter(PythonJavaPrinter): + def _before_syntax(self, tree: J, p: PrintOutputCapture) -> None: + if isinstance(p, TreeStructurePrintOutputCapture): + p.start_node(tree) + super()._before_syntax(tree, p) + + def _after_syntax(self, tree: J, p: PrintOutputCapture) -> None: + super()._after_syntax(tree, p) + if isinstance(p, TreeStructurePrintOutputCapture): + p.end_node() + + +class TreeCapturingPythonPrinter(PythonPrinter): + def __init__(self): + super().__init__() + self._delegate = TreeCapturingPythonJavaPrinter(self) + + def _before_syntax(self, tree: Union[Py, J], p: PrintOutputCapture) -> None: + if isinstance(p, TreeStructurePrintOutputCapture): + p.start_node(tree) + super()._before_syntax(tree, p) + + def _after_syntax(self, tree: Union[Py, J], p: PrintOutputCapture) -> None: + super()._after_syntax(tree, p) + if isinstance(p, TreeStructurePrintOutputCapture): + p.end_node() + + +def _find_whitespace_violations(root_nodes: List[OutputNode]) -> List[str]: + violations: List[str] = [] + + def check_node(node: OutputNode) -> None: + if node.children: + first_child = node.children[0] + if isinstance(first_child, OutputNode): + if first_child.children: + grandchild = first_child.children[0] + if isinstance(grandchild, str) and grandchild.strip() == '' and len(grandchild) > 0: + parent_kind = _prettify_type(node.element) + child_kind = _prettify_type(first_child.element) + violations.append( + f"{parent_kind} has child {child_kind} starting with whitespace " + f"|{grandchild}|. The whitespace should rather be attached to {parent_kind}." + ) + for child in node.children: + if isinstance(child, OutputNode): + check_node(child) + + for node in root_nodes: + check_node(node) + + return violations + + +def _parse_python(source: str) -> Tree: + source_path = Path("test.py") + visitor = ParserVisitor(source, None, None) + tree = ast.parse(source) + cu = visitor.visit_Module(tree) + return cu.replace(source_path=source_path) + + +@pytest.mark.parametrize("source", [ + pytest.param("x = 1", id="simple_assignment"), + pytest.param("def foo(x, y):\n return x + y", id="function_definition"), + pytest.param("d = {'x': 1, 'y': 2}", id="dict_literal"), + pytest.param("class Foo:\n def bar(self):\n pass", id="class_definition"), + pytest.param("from os.path import join", id="import_statement"), + pytest.param("result = [x * 2 for x in range(10) if x > 3]", id="list_comprehension"), + pytest.param("f = lambda x, y: x + y", id="lambda_expression"), + pytest.param("def greet(name):\n greeting = 'Hello, ' + name\n return greeting", id="multiline_function"), + pytest.param("try:\n x = 1\nexcept ValueError:\n x = 0", id="try_except"), + pytest.param("for i in range(10):\n print(i)", id="for_loop"), + pytest.param("if x > 0:\n pass", id="if_condition"), + pytest.param("while running:\n pass", id="while_loop"), + pytest.param("assert x > 0", id="assert_statement"), + pytest.param("x = a + b", id="binary_expression"), + pytest.param("@staticmethod\ndef foo():\n pass", id="decorator"), +]) +def test_whitespace_attachment(source): + # given + cu = _parse_python(source) + capture = TreeStructurePrintOutputCapture() + printer = TreeCapturingPythonPrinter() + + # when + printer.print(cu, capture) + + # then + assert capture.out == source + violations = _find_whitespace_violations(capture.root_nodes) + assert violations == [] + +