Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 44 additions & 14 deletions rewrite-python/rewrite/src/rewrite/python/_parser_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,8 @@ def visit_AnnAssign(self, node):
def visit_For(self, node):
prefix = self.__source_before('for')
target = self.__convert(node.target)
control_prefix = target.prefix
target = target.replace(prefix=Space.EMPTY)
# Wrap target in ExpressionStatement so it can be used as a Statement
wrapped_target = py.ExpressionStatement(random_id(), target)
in_prefix = self.__source_before('in')
Expand All @@ -555,7 +557,7 @@ def visit_For(self, node):

control = j.ForEachLoop.Control(
random_id(),
Space.EMPTY, # No parentheses in Python, so no prefix space for control
control_prefix,
Markers.EMPTY,
self.__pad_right(wrapped_target, in_prefix), # Right padding has space before 'in'
self.__pad_right(iterable, Space.EMPTY) # ':' comes from body's Block prefix
Expand Down Expand Up @@ -589,15 +591,19 @@ def visit_AsyncFor(self, node):
)

def visit_While(self, node):
while_prefix = self.__source_before('while')
condition = self.__convert(node.test)
ctrl_prefix = condition.prefix
condition = condition.replace(prefix=Space.EMPTY)
while_ = j.WhileLoop(
random_id(),
self.__source_before('while'),
while_prefix,
Markers.EMPTY,
j.ControlParentheses(
random_id(),
Space.EMPTY,
ctrl_prefix,
Markers.EMPTY,
self.__pad_right(self.__convert(node.test), Space.EMPTY)
self.__pad_right(condition, Space.EMPTY)
),
self.__pad_right(self.__convert_block(node.body), Space.EMPTY)
)
Expand All @@ -619,8 +625,11 @@ def visit_If(self, node):
prefix = self.__source_before('elif')
else:
prefix = self.__source_before('if')
condition = j.ControlParentheses(random_id(), Space.EMPTY, Markers.EMPTY,
self.__pad_right(self.__convert(node.test), Space.EMPTY))
cond_expr = self.__convert(node.test)
ctrl_prefix = cond_expr.prefix
cond_expr = cond_expr.replace(prefix=Space.EMPTY)
condition = j.ControlParentheses(random_id(), ctrl_prefix, Markers.EMPTY,
self.__pad_right(cond_expr, Space.EMPTY))
then = self.__pad_right(self.__convert_block(node.body), Space.EMPTY)
elze = None
if len(node.orelse) > 0:
Expand Down Expand Up @@ -1118,15 +1127,19 @@ def visit_ExceptHandler(self, node, is_exception_group: bool = False):
)

def visit_Match(self, node):
match_prefix = self.__source_before('match')
subject = self.__convert(node.subject)
ctrl_prefix = subject.prefix
subject = subject.replace(prefix=Space.EMPTY)
return j.Switch(
random_id(),
self.__source_before('match'),
match_prefix,
Markers.EMPTY,
j.ControlParentheses(
random_id(),
Space.EMPTY,
ctrl_prefix,
Markers.EMPTY,
self.__pad_right(self.__convert(node.subject), Space.EMPTY)
self.__pad_right(subject, Space.EMPTY)
),
self.__convert_block(node.cases)
)
Expand Down Expand Up @@ -1856,6 +1869,23 @@ def __convert_binary_operator(self, op) -> Union[JLeftPadded[j.Binary.Type], JLe
raise ValueError(f"Unsupported operator: {op}")
return self.__pad_left(self.__source_before(op_str), op)

@staticmethod
def _hoist_concat_prefix(node):
if not isinstance(node, py.Binary) or node.prefix != Space.EMPTY:
return node
leftmost = node
chain = []
while isinstance(leftmost.left, py.Binary):
chain.append(leftmost)
leftmost = leftmost.left
if leftmost.left.prefix == Space.EMPTY:
return node
prefix = leftmost.left.prefix
leftmost = leftmost.replace(left=leftmost.left.replace(prefix=Space.EMPTY))
for parent in reversed(chain):
leftmost = parent.replace(left=leftmost)
return leftmost.replace(prefix=prefix)

@staticmethod
def _is_byte_string(tok_string: str) -> bool:
"""Check if a string token represents a byte string (has b/B in prefix)."""
Expand Down Expand Up @@ -1927,7 +1957,7 @@ def visit_Constant(self, node):
if idx >= len(self._tokens) - 1 and tok.type == token.ENDMARKER:
break

return res
return self._hoist_concat_prefix(res)

def __map_literal_simple(self, node):
"""Map a non-string constant (numbers, None, True, False, Ellipsis)."""
Expand Down Expand Up @@ -2369,7 +2399,7 @@ def visit_JoinedStr(self, node):

is_first = False

return res
return self._hoist_concat_prefix(res)

def visit_TemplateStr(self, node):
leading_prefix = self.__whitespace()
Expand Down Expand Up @@ -2436,7 +2466,7 @@ def visit_TemplateStr(self, node):

is_first = False

return res
return self._hoist_concat_prefix(res)

def visit_FormattedValue(self, node):
raise ValueError("This method should not be called directly")
Expand Down Expand Up @@ -3360,7 +3390,7 @@ def __map_fstring_as_literal(self, node: ast.JoinedStr, leading_prefix: Space, t
)
is_first = False
assert res is not None
return res
return self._hoist_concat_prefix(res)

def __map_fstring(self, node, prefix: Space, tok: TokenInfo, value_idx: int = 0, *,
_start=None, _middle=None, _end=None) -> \
Expand Down Expand Up @@ -3482,7 +3512,7 @@ def __map_fstring(self, node, prefix: Space, tok: TokenInfo, value_idx: int = 0,
self._type_mapping.type(node)
)

expr = self.__pad_right(nested, Space.EMPTY)
expr = self.__pad_right(self._hoist_concat_prefix(nested), Space.EMPTY)
else:
expr = self.__pad_right(
self.__convert(value_inner),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,13 @@ def visit_catch(self, catch: Try.Catch, p: P) -> J:

def visit_control_parentheses(self, control_parens: ControlParentheses[J2], p: P) -> J:
cp = cast(ControlParentheses[J2], super().visit_control_parentheses(control_parens, p))
cp = space_before(cp, False)
cp = cp.padding.replace(tree=cp.padding.tree.replace(element=space_before(cp.tree, True)))
parent = self.cursor.parent_tree_cursor()
if parent is not None and isinstance(parent.value, Try.Catch):
cp = space_before(cp, False)
cp = cp.padding.replace(tree=cp.padding.tree.replace(element=space_before(cp.tree, True)))
else:
cp = space_before(cp, True)
cp = cp.padding.replace(tree=cp.padding.tree.replace(element=space_before(cp.tree, False)))
return cp

def visit_named_argument(self, named: NamedArgument, p: P) -> J:
Expand Down Expand Up @@ -319,10 +324,10 @@ def visit_for_each_loop(self, for_each: j.ForEachLoop, p: P) -> J:
control = fl.control

# Set single space before loop target e.g. for i in...: <-> for i in ...:
var_rp = control.padding.variable
var_rp = var_rp.replace(element=space_before(var_rp.element, True))
control = space_before(control, True)

# Set single space before 'in' keyword
var_rp = control.padding.variable
var_rp = space_after_right_padded(var_rp, True)

# Set single space before loop iterable e.g. for i in []: <-> for i in []:
Expand Down
173 changes: 173 additions & 0 deletions rewrite-python/rewrite/tests/python/whitespace_attachment_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright 2026 the original author or authors.
# <p>
# Licensed under the Moderne Source Available License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# <p>
# https://docs.moderne.io/licensing/moderne-source-available-license
# <p>
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
from pathlib import Path
from typing import List, Optional, Union

import pytest


from rewrite.java import J
from rewrite.python.support_types import Py
from rewrite.python.printer import PythonPrinter, PythonJavaPrinter, PrintOutputCapture
from rewrite.python._parser_visitor import ParserVisitor
from rewrite import Tree


def _prettify_type(tree: Tree) -> str:
cls = type(tree)
module = cls.__module__
if 'python' in module:
return f"Py.{cls.__name__}"
elif 'java' in module:
return f"J.{cls.__name__}"
return cls.__name__


class OutputNode:
def __init__(self, element: Tree):
self.element = element
self.children: List[Union['OutputNode', str]] = []

def add_child(self, child: Union['OutputNode', str]) -> None:
self.children.append(child)

def __str__(self) -> str:
children_str = ', '.join(
f"Text({child})" if isinstance(child, str) else str(child)
for child in self.children
)
return f"{_prettify_type(self.element)}{{{children_str}}}"


class TreeStructurePrintOutputCapture(PrintOutputCapture):
def __init__(self):
super().__init__()
self.root_nodes: List[OutputNode] = []
self._node_stack: List[OutputNode] = []

def start_node(self, element: Tree) -> None:
node = OutputNode(element)
if self._node_stack:
self._node_stack[-1].add_child(node)
else:
self.root_nodes.append(node)
self._node_stack.append(node)

def end_node(self) -> None:
self._node_stack.pop()

def append(self, text: Optional[str]) -> 'TreeStructurePrintOutputCapture':
if text and len(text) > 0:
if self._node_stack:
self._node_stack[-1].add_child(text)
super().append(text)
return self


class TreeCapturingPythonJavaPrinter(PythonJavaPrinter):
def _before_syntax(self, tree: J, p: PrintOutputCapture) -> None:
if isinstance(p, TreeStructurePrintOutputCapture):
p.start_node(tree)
super()._before_syntax(tree, p)

def _after_syntax(self, tree: J, p: PrintOutputCapture) -> None:
super()._after_syntax(tree, p)
if isinstance(p, TreeStructurePrintOutputCapture):
p.end_node()


class TreeCapturingPythonPrinter(PythonPrinter):
def __init__(self):
super().__init__()
self._delegate = TreeCapturingPythonJavaPrinter(self)

def _before_syntax(self, tree: Union[Py, J], p: PrintOutputCapture) -> None:
if isinstance(p, TreeStructurePrintOutputCapture):
p.start_node(tree)
super()._before_syntax(tree, p)

def _after_syntax(self, tree: Union[Py, J], p: PrintOutputCapture) -> None:
super()._after_syntax(tree, p)
if isinstance(p, TreeStructurePrintOutputCapture):
p.end_node()


def _find_whitespace_violations(root_nodes: List[OutputNode]) -> List[str]:
violations: List[str] = []

def check_node(node: OutputNode) -> None:
if node.children:
first_child = node.children[0]
if isinstance(first_child, OutputNode):
if first_child.children:
grandchild = first_child.children[0]
if isinstance(grandchild, str) and grandchild.strip() == '' and len(grandchild) > 0:
parent_kind = _prettify_type(node.element)
child_kind = _prettify_type(first_child.element)
violations.append(
f"{parent_kind} has child {child_kind} starting with whitespace "
f"|{grandchild}|. The whitespace should rather be attached to {parent_kind}."
)
for child in node.children:
if isinstance(child, OutputNode):
check_node(child)

for node in root_nodes:
check_node(node)

return violations


def _parse_python(source: str) -> Tree:
source_path = Path("test.py")
visitor = ParserVisitor(source, None, None)
tree = ast.parse(source)
cu = visitor.visit_Module(tree)
return cu.replace(source_path=source_path)


@pytest.mark.parametrize("source", [
pytest.param("x = 1", id="simple_assignment"),
pytest.param("def foo(x, y):\n return x + y", id="function_definition"),
pytest.param("d = {'x': 1, 'y': 2}", id="dict_literal"),
pytest.param("class Foo:\n def bar(self):\n pass", id="class_definition"),
pytest.param("from os.path import join", id="import_statement"),
pytest.param("result = [x * 2 for x in range(10) if x > 3]", id="list_comprehension"),
pytest.param("f = lambda x, y: x + y", id="lambda_expression"),
pytest.param("def greet(name):\n greeting = 'Hello, ' + name\n return greeting", id="multiline_function"),
pytest.param("try:\n x = 1\nexcept ValueError:\n x = 0", id="try_except"),
pytest.param("for i in range(10):\n print(i)", id="for_loop"),
pytest.param("if x > 0:\n pass", id="if_condition"),
pytest.param("while running:\n pass", id="while_loop"),
pytest.param("assert x > 0", id="assert_statement"),
pytest.param("x = a + b", id="binary_expression"),
pytest.param("@staticmethod\ndef foo():\n pass", id="decorator"),
])
def test_whitespace_attachment(source):
# given
cu = _parse_python(source)
capture = TreeStructurePrintOutputCapture()
printer = TreeCapturingPythonPrinter()

# when
printer.print(cu, capture)

# then
assert capture.out == source
violations = _find_whitespace_violations(capture.root_nodes)
assert violations == []