From 550c831d21ed4b659b5c6971c2ac0981cb03f453 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 20 Jun 2023 16:17:03 +0200 Subject: [PATCH 1/2] =?UTF-8?q?=E2=9C=A8=20Add=20multiprocessing=20logic?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bump_pydantic/main.py | 106 +++++++++++++++++++++++------------------- 1 file changed, 57 insertions(+), 49 deletions(-) diff --git a/bump_pydantic/main.py b/bump_pydantic/main.py index dc6a7f8..0a2694f 100644 --- a/bump_pydantic/main.py +++ b/bump_pydantic/main.py @@ -1,12 +1,13 @@ import difflib +import functools import multiprocessing import os import time from pathlib import Path -from typing import Any, Dict, Iterator +from typing import Any, Dict, List, Type import libcst as cst -from libcst.codemod import CodemodContext +from libcst.codemod import CodemodContext, ContextAwareTransformer from libcst.helpers import calculate_module_and_package from libcst.metadata import ( FullRepoManager, @@ -69,59 +70,66 @@ def main( codemods = gather_codemods() - # TODO: We can run this in parallel - batch it into files / cores. - with multiprocessing.Pool(): - cpu_count = multiprocessing.cpu_count() - batch_size = len(files) // cpu_count + 1 - - [files[i : i + batch_size] for i in range(0, len(files), batch_size)] - - for filename in files: - module_and_package = calculate_module_and_package(str(package), filename) - context = CodemodContext( - metadata_manager=metadata_manager, - filename=filename, - full_module_name=module_and_package.name, - full_package_name=module_and_package.package, - ) - context.scratch.update(scratch) - - file_path = Path(filename) - with file_path.open("r+") as fp: - code = fp.read() - fp.seek(0) - - input_code = str(code) - - for codemod in codemods: - transformer = codemod(context=context) - - input_tree = cst.parse_module(input_code) - output_tree = transformer.transform_module(input_tree) - - input_code = output_tree.code - - if code != input_code: - if diff: - color_diff( - console=console, - lines=difflib.unified_diff( - code.splitlines(keepends=True), - input_code.splitlines(keepends=True), - fromfile=filename, - tofile=filename, - ), - ) - else: - fp.write(input_code) - fp.truncate() + with multiprocessing.Pool() as pool: + partial_run_codemods = functools.partial(run_codemods, codemods, metadata_manager, scratch, package, diff) + for error_msg in pool.imap_unordered(partial_run_codemods, files): + if isinstance(error_msg, list): + color_diff(console, error_msg) modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time] if modified: print(f"Refactored {len(modified)} files.") -def color_diff(console: Console, lines: Iterator[str]) -> None: +def run_codemods( + codemods: List[Type[ContextAwareTransformer]], + metadata_manager: FullRepoManager, + scratch: Dict[str, Any], + package: Path, + diff: bool, + filename: str, +) -> List[str] | None: + module_and_package = calculate_module_and_package(str(package), filename) + context = CodemodContext( + metadata_manager=metadata_manager, + filename=filename, + full_module_name=module_and_package.name, + full_package_name=module_and_package.package, + ) + context.scratch.update(scratch) + + file_path = Path(filename) + with file_path.open("r+") as fp: + code = fp.read() + fp.seek(0) + + input_code = str(code) + + for codemod in codemods: + transformer = codemod(context=context) + + input_tree = cst.parse_module(input_code) + output_tree = transformer.transform_module(input_tree) + + input_code = output_tree.code + + if code != input_code: + if diff: + lines = difflib.unified_diff( + code.splitlines(keepends=True), + input_code.splitlines(keepends=True), + fromfile=filename, + tofile=filename, + ) + return list(lines) + else: + fp.write(input_code) + fp.truncate() + + return None + + +def color_diff(console: Console, lines: List[str]) -> None: for line in lines: line = line.rstrip("\n") if line.startswith("+"): From a25be664eb632e30b747bac8da4d313a774e3ca5 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 20 Jun 2023 16:21:57 +0200 Subject: [PATCH 2/2] Use future annotations --- bump_pydantic/main.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/bump_pydantic/main.py b/bump_pydantic/main.py index 0a2694f..50f125d 100644 --- a/bump_pydantic/main.py +++ b/bump_pydantic/main.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import difflib import functools import multiprocessing import os import time from pathlib import Path -from typing import Any, Dict, List, Type +from typing import Any import libcst as cst from libcst.codemod import CodemodContext, ContextAwareTransformer @@ -46,7 +48,7 @@ def main( metadata_manager = FullRepoManager(".", files, providers=providers) # type: ignore[arg-type] metadata_manager.resolve_cache() - scratch: Dict[str, Any] = {} + scratch: dict[str, Any] = {} for filename in files: code = Path(filename).read_text() module = cst.parse_module(code) @@ -82,13 +84,13 @@ def main( def run_codemods( - codemods: List[Type[ContextAwareTransformer]], + codemods: list[type[ContextAwareTransformer]], metadata_manager: FullRepoManager, - scratch: Dict[str, Any], + scratch: dict[str, Any], package: Path, diff: bool, filename: str, -) -> List[str] | None: +) -> list[str] | None: module_and_package = calculate_module_and_package(str(package), filename) context = CodemodContext( metadata_manager=metadata_manager, @@ -129,7 +131,7 @@ def run_codemods( return None -def color_diff(console: Console, lines: List[str]) -> None: +def color_diff(console: Console, lines: list[str]) -> None: for line in lines: line = line.rstrip("\n") if line.startswith("+"):