Skip to content

Commit

Permalink
Add logs
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Jul 7, 2023
1 parent 1ea7b36 commit 5bfa0d0
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 91 deletions.
5 changes: 1 addition & 4 deletions bump_pydantic/codemods/mypy_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from mypy.main import process_options
from mypy.nodes import ClassDef
from mypy.traverser import TraverserVisitor
from rich.console import Console

CONTEXT_KEY = "mypy_visitor"

Expand All @@ -22,8 +21,7 @@ def visit_class_def(self, o: ClassDef) -> None:
self.classes[o.fullname] = o.info.has_base("pydantic.main.BaseModel")


def run_mypy_visitor(arg_files: list[str], console: Console | None = None) -> dict[str, bool]:
console = console or Console()
def run_mypy_visitor(arg_files: list[str]) -> dict[str, bool]:
files, opt = process_options(arg_files, stdout=sys.stdout, stderr=sys.stderr)

opt.export_types = True
Expand All @@ -33,7 +31,6 @@ def run_mypy_visitor(arg_files: list[str], console: Console | None = None) -> di
opt.allow_redefinition = True
opt.local_partial_types = True

console.print("Running MyPy - this may take a while...")
result = build(files, opt, stdout=sys.stdout, stderr=sys.stderr)

visitor = MyPyVisitor()
Expand Down
143 changes: 56 additions & 87 deletions bump_pydantic/main.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import difflib
import functools
import logging
import multiprocessing
import os
import time
from contextlib import nullcontext
import traceback
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Type, TypeVar, Union
from typing import Any, Dict, List, Type, TypeVar, Union

import libcst as cst
from libcst.codemod import CodemodContext, ContextAwareTransformer
from libcst.helpers import calculate_module_and_package
from libcst.metadata import FullRepoManager, FullyQualifiedNameProvider, ScopeProvider
from rich.console import Console
from rich.logging import RichHandler
from rich.progress import Progress
from typer import Argument, Exit, Option, Typer, echo
from typing_extensions import ParamSpec
Expand All @@ -30,6 +30,10 @@
T = TypeVar("T")


logging.basicConfig(level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()])
logger = logging.getLogger("bump_pydantic")


def version_callback(value: bool):
if value:
echo(f"bump-pydantic version: {__version__}")
Expand All @@ -39,9 +43,8 @@ def version_callback(value: bool):
@app.callback()
def main(
package: Path = Argument(..., exists=True, dir_okay=True, allow_dash=False),
diff: bool = Option(False, help="Show diff instead of applying changes."),
disable: List[Rule] = Option(default=[], help="Disable a rule."),
log_file: Union[Path, None] = Option(None, help="Log file to write to."),
log_file: Path = Option("log.txt", help="Log errors to this file."),
version: bool = Option(
None,
"--version",
Expand All @@ -50,117 +53,83 @@ def main(
help="Show the version and exit.",
),
):
logger.info("Start bump-pydantic.")
# NOTE: LIBCST_PARSER_TYPE=native is required according to https://github.com/Instagram/LibCST/issues/487.
os.environ["LIBCST_PARSER_TYPE"] = "native"

console = Console()
files_str = list(package.glob("**/*.py"))
files = [str(file.relative_to(".")) for file in files_str]
logger.info(f"Found {len(files)} files to process.")

providers = {FullyQualifiedNameProvider, ScopeProvider}
metadata_manager = FullRepoManager(".", files, providers=providers) # type: ignore[arg-type]
metadata_manager.resolve_cache()

classes = run_mypy_visitor(files, console=console)
logger.info("Running mypy to get type information. This may take a while...")
classes = run_mypy_visitor(files)
scratch: dict[str, Any] = {CONTEXT_KEY: classes}
logger.info("Finished mypy.")

start_time = time.time()

codemods = gather_codemods(disabled=disable)

log_ctx_mgr = log_file.open("a+") if log_file else nullcontext()
partial_run_codemods = functools.partial(run_codemods, codemods, metadata_manager, scratch, package, diff)

log_fp = log_file.open("a+")
partial_run_codemods = functools.partial(run_codemods, codemods, metadata_manager, scratch, package)
with Progress(*Progress.get_default_columns(), transient=True) as progress:
task = progress.add_task(description="Executing codemods...", total=len(files))
with multiprocessing.Pool() as pool, log_ctx_mgr as log_fp: # type: ignore[attr-defined]
for error_msg in pool.imap(partial_run_codemods, files):
count_errors = 0
with multiprocessing.Pool() as pool:
for error in pool.imap_unordered(partial_run_codemods, files):
progress.advance(task)
if error_msg is None:
continue

if log_fp is None:
color_diff(console, error_msg)
else:
log_fp.writelines(error_msg)

if log_fp:
log_fp.write("Run successfully!\n")
if error is not None:
count_errors += 1
log_fp.writelines(error)

modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time]
if modified:
print(f"Refactored {len(modified)} files.")

if modified:
logger.info(f"Refactored {len(modified)} files.")

def capture_exception(func: Callable[P, T]) -> Callable[P, Union[T, Iterable[str]]]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Union[T, Iterable[str]]:
try:
return func(*args, **kwargs)
except Exception as exc:
func_args = [repr(arg) for arg in args]
func_kwargs = [f"{key}={repr(value)}" for key, value in kwargs.items()]
return [f"{func.__name__}({', '.join(func_args + func_kwargs)})\n{exc}"]

return wrapper
if count_errors > 0:
logger.info(f"Found {count_errors} errors. Please check the {log_file} file.")
else:
logger.info("Run successfully!")


@capture_exception
def run_codemods(
codemods: List[Type[ContextAwareTransformer]],
metadata_manager: FullRepoManager,
scratch: Dict[str, Any],
package: Path,
diff: bool,
filename: str,
) -> Union[List[str], None]:
module_and_package = calculate_module_and_package(str(package), filename)
context = CodemodContext(
metadata_manager=metadata_manager,
filename=filename,
full_module_name=module_and_package.name,
full_package_name=module_and_package.package,
)
context.scratch.update(scratch)

file_path = Path(filename)
with file_path.open("r+") as fp:
code = fp.read()
fp.seek(0)

input_tree = cst.parse_module(code)

for codemod in codemods:
transformer = codemod(context=context)

output_tree = transformer.transform_module(input_tree)
input_tree = output_tree

output_code = input_tree.code
if code != output_code:
if diff:
lines = difflib.unified_diff(
code.splitlines(keepends=True),
output_code.splitlines(keepends=True),
fromfile=filename,
tofile=filename,
)
return list(lines)
else:
) -> Union[str, None]:
try:
module_and_package = calculate_module_and_package(str(package), filename)
context = CodemodContext(
metadata_manager=metadata_manager,
filename=filename,
full_module_name=module_and_package.name,
full_package_name=module_and_package.package,
)
context.scratch.update(scratch)

file_path = Path(filename)
with file_path.open("r+") as fp:
code = fp.read()
fp.seek(0)

input_tree = cst.parse_module(code)

for codemod in codemods:
transformer = codemod(context=context)
output_tree = transformer.transform_module(input_tree)
input_tree = output_tree

output_code = input_tree.code
if code != output_code:
fp.write(output_code)
fp.truncate()

return None


def color_diff(console: Console, lines: Iterable[str]) -> None:
for line in lines:
line = line.rstrip("\n")
if line.startswith("+"):
console.print(line, style="green")
elif line.startswith("-"):
console.print(line, style="red")
elif line.startswith("^"):
console.print(line, style="blue")
else:
console.print(line, style="white")
return None
except Exception:
return f"An error happened on {filename}.\n{traceback.format_exc()}"

0 comments on commit 5bfa0d0

Please sign in to comment.