Skip to content

Commit

Permalink
✨ Add --diff argument
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Jul 17, 2023
1 parent 3bd8332 commit 7b83d4b
Showing 1 changed file with 54 additions and 22 deletions.
76 changes: 54 additions & 22 deletions bump_pydantic/main.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import difflib
import functools
import logging
import multiprocessing
import os
import time
import traceback
from pathlib import Path
from typing import Any, Dict, List, Type, TypeVar, Union
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union

import libcst as cst
from libcst.codemod import CodemodContext, ContextAwareTransformer
from libcst.helpers import calculate_module_and_package
from libcst.metadata import FullRepoManager, FullyQualifiedNameProvider, ScopeProvider
from rich.logging import RichHandler
from rich.console import Console
from rich.progress import Progress
from typer import Argument, Exit, Option, Typer, echo
from typing_extensions import ParamSpec
Expand All @@ -30,10 +30,6 @@
T = TypeVar("T")


logging.basicConfig(level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()])
logger = logging.getLogger("bump_pydantic")


def version_callback(value: bool):
if value:
echo(f"bump-pydantic version: {__version__}")
Expand All @@ -44,6 +40,7 @@ def version_callback(value: bool):
def main(
path: Path = Argument(..., exists=True, dir_okay=True, allow_dash=False),
disable: List[Rule] = Option(default=[], help="Disable a rule."),
diff: bool = Option(False, help="Show diff instead of applying changes."),
log_file: Path = Option("log.txt", help="Log errors to this file."),
version: bool = Option(
None,
Expand All @@ -53,7 +50,8 @@ def main(
help="Show the version and exit.",
),
):
logger.info("Start bump-pydantic.")
console = Console(log_time=True)
console.log("Start bump-pydantic.")
# NOTE: LIBCST_PARSER_TYPE=native is required according to https://github.com/Instagram/LibCST/issues/487.
os.environ["LIBCST_PARSER_TYPE"] = "native"

Expand All @@ -65,51 +63,63 @@ def main(
files_str = list(package.glob("**/*.py"))
files = [str(file.relative_to(".")) for file in files_str]

logger.info(f"Found {len(files)} files to process.")
console.log(f"Found {len(files)} files to process")

providers = {FullyQualifiedNameProvider, ScopeProvider}
metadata_manager = FullRepoManager(".", files, providers=providers) # type: ignore[arg-type]
metadata_manager.resolve_cache()

logger.info("Running mypy to get type information. This may take a while...")
console.log("Running mypy to get type information. This may take a while...")
classes = run_mypy_visitor(files)
scratch: dict[str, Any] = {CONTEXT_KEY: classes}
logger.info("Finished mypy.")
console.log("Finished mypy.")

start_time = time.time()

codemods = gather_codemods(disabled=disable)

log_fp = log_file.open("a+")
partial_run_codemods = functools.partial(run_codemods, codemods, metadata_manager, scratch, package)
partial_run_codemods = functools.partial(run_codemods, codemods, metadata_manager, scratch, package, diff)
with Progress(*Progress.get_default_columns(), transient=True) as progress:
task = progress.add_task(description="Executing codemods...", total=len(files))
count_errors = 0
difflines: List[List[str]] = []
with multiprocessing.Pool() as pool:
for error in pool.imap_unordered(partial_run_codemods, files):
for error, _difflines in pool.imap_unordered(partial_run_codemods, files):
progress.advance(task)

if _difflines is not None:
difflines.append(_difflines)

if error is not None:
count_errors += 1
log_fp.writelines(error)

modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time]

if modified:
logger.info(f"Refactored {len(modified)} files.")
if modified and not diff:
console.log(f"Refactored {len(modified)} files.")

for _difflines in difflines:
color_diff(console, _difflines)

if count_errors > 0:
logger.info(f"Found {count_errors} errors. Please check the {log_file} file.")
console.log(f"Found {count_errors} errors. Please check the {log_file} file.")
else:
logger.info("Run successfully!")
console.log("Run successfully!")

if difflines:
raise Exit(1)


def run_codemods(
codemods: List[Type[ContextAwareTransformer]],
metadata_manager: FullRepoManager,
scratch: Dict[str, Any],
package: Path,
diff: bool,
filename: str,
) -> Union[str, None]:
) -> Tuple[Union[str, None], Union[List[str], None]]:
try:
module_and_package = calculate_module_and_package(str(package), filename)
context = CodemodContext(
Expand All @@ -134,8 +144,30 @@ def run_codemods(

output_code = input_tree.code
if code != output_code:
fp.write(output_code)
fp.truncate()
return None
if diff:
lines = difflib.unified_diff(
code.splitlines(keepends=True),
output_code.splitlines(keepends=True),
fromfile=filename,
tofile=filename,
)
return None, list(lines)
else:
fp.write(output_code)
fp.truncate()
return None, None
except Exception:
return f"An error happened on {filename}.\n{traceback.format_exc()}"
return f"An error happened on {filename}.\n{traceback.format_exc()}", None


def color_diff(console: Console, lines: Iterable[str]) -> None:
for line in lines:
line = line.rstrip("\n")
if line.startswith("+"):
console.print(line, style="green")
elif line.startswith("-"):
console.print(line, style="red")
elif line.startswith("^"):
console.print(line, style="blue")
else:
console.print(line, style="white")

0 comments on commit 7b83d4b

Please sign in to comment.