Skip to content

Commit

Permalink
skip redundant ast->str->ast conversion; contain concat and join exce…
Browse files Browse the repository at this point in the history
…ptions
  • Loading branch information
ikamensh committed Jun 17, 2023
1 parent 0f685d2 commit 7558aeb
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 82 deletions.
37 changes: 25 additions & 12 deletions src/flynt/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import sys
import time
import traceback
from difflib import unified_diff
from typing import Collection, List, Optional, Tuple

Expand Down Expand Up @@ -94,19 +95,31 @@ def fstringify_code(
state=state,
)
if state.transform_concat:
new_code, concat_changes = fstringify_concats(
new_code,
state=state,
)
changes += concat_changes
state.concat_changes += concat_changes
try:
new_code, concat_changes = fstringify_concats(
new_code,
state=state,
)
except Exception:
msg = traceback.format_exc()
log.error("Transforming concatenation of literal strings failed")
log.error(msg)
else:
changes += concat_changes
state.concat_changes += concat_changes
if state.transform_join:
new_code, join_changes = fstringify_static_joins(
new_code,
state=state,
)
changes += join_changes
state.join_changes += join_changes
try:
new_code, join_changes = fstringify_static_joins(
new_code,
state=state,
)
except Exception:
msg = traceback.format_exc()
log.error("Transforming concatenation of literal strings failed")
log.error(msg)
else:
changes += join_changes
state.join_changes += join_changes

except Exception as e:
msg = str(e) or e.__class__.__name__
Expand Down
2 changes: 1 addition & 1 deletion src/flynt/code_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def try_chunk(self, chunk: AstChunk) -> None:
except FlyntException:
quote_type = qt.double

converted, changed = self.transform_func(str(chunk), quote_type=quote_type)
converted, changed = self.transform_func(chunk.node, quote_type=quote_type)
if changed:
contract_lines = chunk.n_lines - 1
if contract_lines == 0:
Expand Down
3 changes: 1 addition & 2 deletions src/flynt/static_join/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def visit_Call(self, node: ast.Call):
return ast.JoinedStr(args_with_interleaved_joiner)


def transform_join(code: str, *args, **kwargs) -> Tuple[str, bool]:
tree = ast.parse(f"({code})")
def transform_join(tree: ast.AST, *args, **kwargs) -> Tuple[str, bool]:

jt = JoinTransformer()
jt.visit(tree)
Expand Down
7 changes: 3 additions & 4 deletions src/flynt/string_concat/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,10 @@ def visit_BinOp(self, node: ast.BinOp) -> ast.AST:
return ast.JoinedStr(segments)


def transform_concat(code: str, *args, **kwargs) -> Tuple[str, bool]:
tree = ast.parse(f"({code})")
def transform_concat(tree: ast.AST, *args, **kwargs) -> Tuple[str, bool]:

ft = ConcatTransformer()
ft.visit(tree)
new_code = fixup_transformed(tree)
new = ft.visit(tree)
new_code = fixup_transformed(new)

return new_code, ft.counter > 0
23 changes: 12 additions & 11 deletions src/flynt/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,34 @@


def transform_chunk(
code: str,
tree: ast.AST,
state: State,
quote_type: str = QuoteTypes.triple_double,
) -> Tuple[str, bool]:
"""Convert a block of code to an f-string
Args:
tree: The code to convert as AST.
state: State object, for settings and statistics
code: The code to convert.
quote_type: the quote type to use for the transformed result
Returns:
Tuple: resulting code, boolean: was it changed?
"""
try:
tree = ast.parse(code)
converted, changed = fstringify_node(
copy.deepcopy(tree),
state=state,
)
str_in_str = str_in_str_fn(converted)
except ConversionRefused as cr:
log.warning("Not converting code '%s': %s", code, cr)
log.warning("Not converting code due to: %s", cr)
state.invalid_conversions += 1
return code, False
return None, False # type:ignore # ideally should return one optional str
except Exception:
msg = traceback.format_exc()
log.exception("Exception during conversion of code '%s': %s", code, msg)
log.exception("Exception during conversion of code: %s", msg)
state.invalid_conversions += 1
return code, False
return None, False # type:ignore # ideally should return one optional str
else:
if changed:
if str_in_str and quote_type == QuoteTypes.single:
Expand All @@ -53,14 +52,16 @@ def transform_chunk(
ast.parse(new_code)
except SyntaxError:
log.error(
"Failed to parse transformed code '%s'' given original '%s'",
"Failed to parse transformed code '%s'",
new_code,
code,
exc_info=True,
)
state.invalid_conversions += 1
return code, False
return (
None,
False,
) # type:ignore # ideally should return one optional str
else:
return new_code, changed

return code, False
return None, False # type:ignore # ideally should return one optional str
2 changes: 0 additions & 2 deletions test/integration/test_concat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
""" Test str processors on actual file contents """
import sys
from test.integration.utils import concat_samples, try_on_file

import pytest
Expand All @@ -15,7 +14,6 @@ def fstringify_and_concats(code: str):
return code, count_a + count_b


@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
@pytest.mark.parametrize("filename_concat", concat_samples)
def test_fstringify_concat(filename_concat):
out, expected = try_on_file(
Expand Down
3 changes: 2 additions & 1 deletion test/test_static_join/test_sj_transformer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import sys
from test.test_static_join.utils import CASES
from typing import Optional
Expand All @@ -13,7 +14,7 @@

@pytest.mark.parametrize("source, expected", CASES)
def test_transform(source: str, expected: Optional[str]):
new, changed = transform_join(source)
new, changed = transform_join(ast.parse(source))
if changed:
assert new == expected
else:
Expand Down
56 changes: 29 additions & 27 deletions test/test_str_concat/test_transformer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import ast
import sys

import pytest

from flynt.state import State
from flynt.string_concat.transformer import transform_concat, unpack_binop


@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
def transform_concat_from_str(code: str, state=State()):
tree = ast.parse(code)
return transform_concat(tree, state)


def test_unpack():

txt = """a + 'Hello' + b + 'World'"""
Expand All @@ -29,144 +32,143 @@ def test_unpack():
assert seq[3].value == "World"


@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
def test_transform():

txt = """a + 'Hello' + b + 'World'"""
expected = '''f"{a}Hello{b}World"'''

new, changed = transform_concat(txt)
new, changed = transform_concat_from_str(txt)

assert changed
assert new == expected


@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
def test_transform_nonatomic():

txt = """'blah' + (thing - 1)"""
expected = '''f"blah{thing - 1}"'''

new, changed = transform_concat(txt)
new, changed = transform_concat_from_str(txt)

assert changed
assert new == expected


@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
def test_transform_attribute():

txt = """'blah' + blah.blah"""
expected = '''f"blah{blah.blah}"'''

new, changed = transform_concat(txt)
new, changed = transform_concat_from_str(txt)

assert changed
assert new == expected


@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
def test_transform_complex():

txt = """'blah' + lst[123].process(x, y, z) + 'Yeah'"""
expected = '''f"blah{lst[123].process(x, y, z)}Yeah"'''

new, changed = transform_concat(txt)
new, changed = transform_concat_from_str(txt)

assert changed
assert new == expected


@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
def test_string_in_string():

txt = """'blah' + blah.blah('more' + vars)"""
expected = '''f"blah{blah.blah(f'more{vars}')}"'''

new, changed = transform_concat(txt)
new, changed = transform_concat_from_str(txt)

assert changed
assert new == expected


@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
def test_concats_fstring():

txt = """print(f'blah{thing}' + 'blah' + otherThing + f"is {x:d}")"""
expected = """print(f'blah{thing}blah{otherThing}is {x:d}')"""

new, changed = transform_concat(txt)
new, changed = transform_concat_from_str(txt)

assert changed
assert new == expected


@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
def test_string_in_string_x3():

txt = """'blah' + blah.blah('more' + vars.foo('other' + b))"""

new, changed = transform_concat(txt)
new, changed = transform_concat_from_str(txt)

assert changed
assert "'blah' +" in new


@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
def test_existing_fstr():

txt = """f'blah{thing}' + otherThing + 'blah'"""
expected = '''f"blah{thing}{otherThing}blah"'''

new, changed = transform_concat(txt)
new, changed = transform_concat_from_str(txt)

assert changed
assert new == expected


@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
def test_existing_fstr_expr():

txt = """f'blah{thing}' + otherThing + f'blah{thing + 1}'"""
expected = '''f"blah{thing}{otherThing}blah{thing + 1}"'''

new, changed = transform_concat(txt)
new, changed = transform_concat_from_str(txt)

assert changed
assert new == expected


@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
def test_embedded_fstr():

txt = """print(f"{f'blah{var}' + abc}blah")"""
expected = """print(f'blah{var}{abc}blah')"""

new, changed = transform_concat(txt)
new, changed = transform_concat_from_str(txt)

assert changed
assert new == expected


@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
def test_backslash():
txt = """blah1 \
+ 'b'"""

expected = '''f"{blah1}b"'''
new, changed = transform_concat(txt)
new, changed = transform_concat_from_str(txt)

assert changed
assert new == expected


@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
def test_parens():
txt = """(blah1
+ 'b')"""

expected = '''f"{blah1}b"'''
new, changed = transform_concat(txt)
new, changed = transform_concat_from_str(txt)

assert changed
assert new == expected


noexc_in = """individual_tests = [re.sub(r"\.py$", "", test) + ".py" for test in tests if not test.endswith('*')]"""
noexc_out = """individual_tests = [f"{re.sub(r"\.py$", "", test)}.py" for test in tests if not test.endswith('*')]"""


def test_noexc():
new, changed = transform_concat_from_str(noexc_in)
# TODO this doesn't produce expected output - number of escapes is surprising
# assert changed
# assert new == noexc_out
Loading

0 comments on commit 7558aeb

Please sign in to comment.