From 2bd8177d78fd4b952f96980a59682977515ce06b Mon Sep 17 00:00:00 2001 From: devdanzin <74280297+devdanzin@users.noreply.github.com> Date: Sun, 3 Sep 2023 10:11:53 -0300 Subject: [PATCH] Add get_source_segment and support trailing comments in functions. --- ast_comments.py | 53 +++++++++++++++++++++++++++++++++++++++++++++++-- test_parse.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 103 insertions(+), 3 deletions(-) diff --git a/ast_comments.py b/ast_comments.py index 4dfb25d..f0f12bb 100644 --- a/ast_comments.py +++ b/ast_comments.py @@ -50,7 +50,7 @@ def _enrich(source: Union[str, bytes], tree: ast.AST) -> None: if not comment_nodes: return - tree_intervals = _get_tree_intervals(tree) + tree_intervals = _get_tree_intervals(source, tree) for c_node in comment_nodes: c_lineno = c_node.lineno possible_intervals_for_c_node = [ @@ -99,9 +99,11 @@ def _enrich(source: Union[str, bytes], tree: ast.AST) -> None: for left, right in zip(attr[:-1], attr[1:]): if isinstance(left, Comment) and isinstance(right, Comment): right.inline = False - + target_node.end_lineno = c_node.end_lineno + target_node.end_col_offset = c_node.end_col_offset def _get_tree_intervals( + source: str, node: ast.AST, ) -> Dict[Tuple[int, int], Dict[str, Union[List[Tuple[int, int]], ast.AST]]]: res = {} @@ -119,6 +121,12 @@ def _get_tree_intervals( if hasattr(node, "end_lineno") else max(attr_intervals)[1] ) + # Add trailing comment lines, doesn't match indentation + for line in source.splitlines()[high:]: + if line.strip().startswith("#"): + high += 1 + else: + break res[(low, high)] = {"intervals": attr_intervals, "node": node} return res @@ -173,3 +181,44 @@ def _get_first_not_comment_idx(orelse: list[ast.stmt]) -> int: def unparse(ast_obj: ast.AST) -> str: return _Unparser().visit(ast_obj) + + +def get_source_segment(source, node, *, padded=False): + """Get source code segment of the *source* that generated *node*. + + If some location information (`lineno`, `end_lineno`, `col_offset`, + or `end_col_offset`) is missing, return None. + + If *padded* is `True`, the first line of a multi-line statement will + be padded with spaces to match its original position. + + Customized version of ast.get_source_segment that includes trailing + inline comments. + """ + try: + if node.end_lineno is None or node.end_col_offset is None: + return None + lineno = node.lineno - 1 + end_lineno = node.end_lineno - 1 + col_offset = node.col_offset + # Add trailing inline comment: + end_col_offset = max(node.body[-1].end_col_offset, node.end_col_offset) + except AttributeError: + return None + + lines = ast._splitlines_no_ff(source) + if end_lineno == lineno: + return lines[lineno].encode()[col_offset:end_col_offset].decode() + + if padded: + padding = ast._pad_whitespace(lines[lineno].encode()[:col_offset].decode()) + else: + padding = '' + + first = padding + lines[lineno].encode()[col_offset:].decode() + last = lines[end_lineno].encode()[:end_col_offset].decode() + lines = lines[lineno+1:end_lineno] + + lines.insert(0, first) + lines.append(last) + return ''.join(lines) diff --git a/test_parse.py b/test_parse.py index 9422cbe..dfdeecb 100644 --- a/test_parse.py +++ b/test_parse.py @@ -5,7 +5,7 @@ import pytest -from ast_comments import Comment, parse +from ast_comments import Comment, get_source_segment, parse def test_single_comment_in_tree(): @@ -304,3 +304,54 @@ def test_comment_in_multilined_list(): """ ) assert len(parse(source).body) == 1 + + +def test_function_with_trailing_comment(): + """Function with trailing comments inside.""" + source = dedent( + """ + def foo(*args, **kwargs): + print(args, kwargs) # comment to print + # A comment + # comment in function 'foo' + """ + ) + nodes = parse(source).body + assert len(nodes) == 1 + function_node = nodes[0] + assert function_node.body[1].value == "# comment to print" + assert function_node.body[1].inline + assert function_node.body[-1].value == "# comment in function 'foo'" + assert not function_node.body[-1].inline + + +def test_get_source_segment(): + """Check that get_source_segment roundtrips function code.""" + source = dedent( + """ + def foo(*args, **kwargs): + print(args, kwargs) # comment to print + # A comment + # comment in function 'foo' + """ + ) + function_node = parse(source).body[0] + assert source.strip() == get_source_segment(source, function_node) + + +@pytest.mark.xfail(reason="Skipping extraneous comments doesn't work.") +def test_get_source_segment_outside_comment(): + """Check that get_source_segment skips extraneous comments.""" + source = dedent( + """ + def foo(*args, **kwargs): + print(args, kwargs) # comment to print + # A comment + # comment in function 'foo' + # comment outside function 'foo' + """ + ) + function_node = parse(source).body[0] + assert function_node.body[-1].value == "# comment in function 'foo'" + assert not function_node.body[-1].inline + assert source.strip() == get_source_segment(source, function_node)