diff --git a/bump_pydantic/main.py b/bump_pydantic/main.py index dc6a7f8..50f125d 100644 --- a/bump_pydantic/main.py +++ b/bump_pydantic/main.py @@ -1,12 +1,15 @@ +from __future__ import annotations + import difflib +import functools import multiprocessing import os import time from pathlib import Path -from typing import Any, Dict, Iterator +from typing import Any import libcst as cst -from libcst.codemod import CodemodContext +from libcst.codemod import CodemodContext, ContextAwareTransformer from libcst.helpers import calculate_module_and_package from libcst.metadata import ( FullRepoManager, @@ -45,7 +48,7 @@ def main( metadata_manager = FullRepoManager(".", files, providers=providers) # type: ignore[arg-type] metadata_manager.resolve_cache() - scratch: Dict[str, Any] = {} + scratch: dict[str, Any] = {} for filename in files: code = Path(filename).read_text() module = cst.parse_module(code) @@ -69,59 +72,66 @@ def main( codemods = gather_codemods() - # TODO: We can run this in parallel - batch it into files / cores. - with multiprocessing.Pool(): - cpu_count = multiprocessing.cpu_count() - batch_size = len(files) // cpu_count + 1 - - [files[i : i + batch_size] for i in range(0, len(files), batch_size)] - - for filename in files: - module_and_package = calculate_module_and_package(str(package), filename) - context = CodemodContext( - metadata_manager=metadata_manager, - filename=filename, - full_module_name=module_and_package.name, - full_package_name=module_and_package.package, - ) - context.scratch.update(scratch) - - file_path = Path(filename) - with file_path.open("r+") as fp: - code = fp.read() - fp.seek(0) - - input_code = str(code) - - for codemod in codemods: - transformer = codemod(context=context) - - input_tree = cst.parse_module(input_code) - output_tree = transformer.transform_module(input_tree) - - input_code = output_tree.code - - if code != input_code: - if diff: - color_diff( - console=console, - lines=difflib.unified_diff( - code.splitlines(keepends=True), - input_code.splitlines(keepends=True), - fromfile=filename, - tofile=filename, - ), - ) - else: - fp.write(input_code) - fp.truncate() + with multiprocessing.Pool() as pool: + partial_run_codemods = functools.partial(run_codemods, codemods, metadata_manager, scratch, package, diff) + for error_msg in pool.imap_unordered(partial_run_codemods, files): + if isinstance(error_msg, list): + color_diff(console, error_msg) modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time] if modified: print(f"Refactored {len(modified)} files.") -def color_diff(console: Console, lines: Iterator[str]) -> None: +def run_codemods( + codemods: list[type[ContextAwareTransformer]], + metadata_manager: FullRepoManager, + scratch: dict[str, Any], + package: Path, + diff: bool, + filename: str, +) -> list[str] | None: + module_and_package = calculate_module_and_package(str(package), filename) + context = CodemodContext( + metadata_manager=metadata_manager, + filename=filename, + full_module_name=module_and_package.name, + full_package_name=module_and_package.package, + ) + context.scratch.update(scratch) + + file_path = Path(filename) + with file_path.open("r+") as fp: + code = fp.read() + fp.seek(0) + + input_code = str(code) + + for codemod in codemods: + transformer = codemod(context=context) + + input_tree = cst.parse_module(input_code) + output_tree = transformer.transform_module(input_tree) + + input_code = output_tree.code + + if code != input_code: + if diff: + lines = difflib.unified_diff( + code.splitlines(keepends=True), + input_code.splitlines(keepends=True), + fromfile=filename, + tofile=filename, + ) + return list(lines) + else: + fp.write(input_code) + fp.truncate() + + return None + + +def color_diff(console: Console, lines: list[str]) -> None: for line in lines: line = line.rstrip("\n") if line.startswith("+"):