diff --git a/bump_pydantic/main.py b/bump_pydantic/main.py index d387f02..3bdfe30 100644 --- a/bump_pydantic/main.py +++ b/bump_pydantic/main.py @@ -1,17 +1,17 @@ +import difflib import functools -import logging import multiprocessing import os import time import traceback from pathlib import Path -from typing import Any, Dict, List, Type, TypeVar, Union +from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union import libcst as cst from libcst.codemod import CodemodContext, ContextAwareTransformer from libcst.helpers import calculate_module_and_package from libcst.metadata import FullRepoManager, FullyQualifiedNameProvider, ScopeProvider -from rich.logging import RichHandler +from rich.console import Console from rich.progress import Progress from typer import Argument, Exit, Option, Typer, echo from typing_extensions import ParamSpec @@ -30,10 +30,6 @@ T = TypeVar("T") -logging.basicConfig(level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]) -logger = logging.getLogger("bump_pydantic") - - def version_callback(value: bool): if value: echo(f"bump-pydantic version: {__version__}") @@ -44,6 +40,7 @@ def version_callback(value: bool): def main( path: Path = Argument(..., exists=True, dir_okay=True, allow_dash=False), disable: List[Rule] = Option(default=[], help="Disable a rule."), + diff: bool = Option(False, help="Show diff instead of applying changes."), log_file: Path = Option("log.txt", help="Log errors to this file."), version: bool = Option( None, @@ -53,7 +50,8 @@ def main( help="Show the version and exit.", ), ): - logger.info("Start bump-pydantic.") + console = Console(log_time=True) + console.log("Start bump-pydantic.") # NOTE: LIBCST_PARSER_TYPE=native is required according to https://github.com/Instagram/LibCST/issues/487. os.environ["LIBCST_PARSER_TYPE"] = "native" @@ -65,42 +63,53 @@ def main( files_str = list(package.glob("**/*.py")) files = [str(file.relative_to(".")) for file in files_str] - logger.info(f"Found {len(files)} files to process.") + console.log(f"Found {len(files)} files to process") providers = {FullyQualifiedNameProvider, ScopeProvider} metadata_manager = FullRepoManager(".", files, providers=providers) # type: ignore[arg-type] metadata_manager.resolve_cache() - logger.info("Running mypy to get type information. This may take a while...") + console.log("Running mypy to get type information. This may take a while...") classes = run_mypy_visitor(files) scratch: dict[str, Any] = {CONTEXT_KEY: classes} - logger.info("Finished mypy.") + console.log("Finished mypy.") start_time = time.time() codemods = gather_codemods(disabled=disable) log_fp = log_file.open("a+") - partial_run_codemods = functools.partial(run_codemods, codemods, metadata_manager, scratch, package) + partial_run_codemods = functools.partial(run_codemods, codemods, metadata_manager, scratch, package, diff) with Progress(*Progress.get_default_columns(), transient=True) as progress: task = progress.add_task(description="Executing codemods...", total=len(files)) count_errors = 0 + difflines: List[List[str]] = [] with multiprocessing.Pool() as pool: - for error in pool.imap_unordered(partial_run_codemods, files): + for error, _difflines in pool.imap_unordered(partial_run_codemods, files): progress.advance(task) + + if _difflines is not None: + difflines.append(_difflines) + if error is not None: count_errors += 1 log_fp.writelines(error) modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time] - if modified: - logger.info(f"Refactored {len(modified)} files.") + if modified and not diff: + console.log(f"Refactored {len(modified)} files.") + + for _difflines in difflines: + color_diff(console, _difflines) if count_errors > 0: - logger.info(f"Found {count_errors} errors. Please check the {log_file} file.") + console.log(f"Found {count_errors} errors. Please check the {log_file} file.") else: - logger.info("Run successfully!") + console.log("Run successfully!") + + if difflines: + raise Exit(1) def run_codemods( @@ -108,8 +117,9 @@ def run_codemods( metadata_manager: FullRepoManager, scratch: Dict[str, Any], package: Path, + diff: bool, filename: str, -) -> Union[str, None]: +) -> Tuple[Union[str, None], Union[List[str], None]]: try: module_and_package = calculate_module_and_package(str(package), filename) context = CodemodContext( @@ -134,8 +144,30 @@ def run_codemods( output_code = input_tree.code if code != output_code: - fp.write(output_code) - fp.truncate() - return None + if diff: + lines = difflib.unified_diff( + code.splitlines(keepends=True), + output_code.splitlines(keepends=True), + fromfile=filename, + tofile=filename, + ) + return None, list(lines) + else: + fp.write(output_code) + fp.truncate() + return None, None except Exception: - return f"An error happened on {filename}.\n{traceback.format_exc()}" + return f"An error happened on {filename}.\n{traceback.format_exc()}", None + + +def color_diff(console: Console, lines: Iterable[str]) -> None: + for line in lines: + line = line.rstrip("\n") + if line.startswith("+"): + console.print(line, style="green") + elif line.startswith("-"): + console.print(line, style="red") + elif line.startswith("^"): + console.print(line, style="blue") + else: + console.print(line, style="white")