Skip to content

Commit

Permalink
Use mypy visitor
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Jul 6, 2023
1 parent e5e9148 commit 1ea7b36
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 275 deletions.
8 changes: 4 additions & 4 deletions bump_pydantic/codemods/add_default_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.metadata import FullyQualifiedNameProvider, QualifiedName

from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor
from bump_pydantic.codemods.mypy_visitor import CONTEXT_KEY


class AddDefaultNoneCommand(VisitorBasedCodemodCommand):
Expand Down Expand Up @@ -49,7 +49,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None:
return None

fqn: QualifiedName = next(iter(fqn_set)) # type: ignore
if fqn.name in self.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY]:
if self.context.scratch[CONTEXT_KEY].get(fqn.name, False):
self.inside_base_model = True

def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
Expand Down Expand Up @@ -120,8 +120,8 @@ class Bar(Foo):
wrapper = mrg.get_metadata_wrapper_for_path(module)
context = CodemodContext(wrapper=wrapper)

command = ClassDefVisitor(context=context)
mod = wrapper.visit(command)
# classes = run_mypy_visitor(context=context)
# mod = wrapper.visit(command)

command = AddDefaultNoneCommand(context=context) # type: ignore[assignment]
mod = wrapper.visit(command)
Expand Down
144 changes: 0 additions & 144 deletions bump_pydantic/codemods/class_def_visitor.py

This file was deleted.

55 changes: 55 additions & 0 deletions bump_pydantic/codemods/mypy_visitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from __future__ import annotations

import sys
from argparse import ArgumentParser

from mypy.build import build
from mypy.main import process_options
from mypy.nodes import ClassDef
from mypy.traverser import TraverserVisitor
from rich.console import Console

CONTEXT_KEY = "mypy_visitor"


class MyPyVisitor(TraverserVisitor):
def __init__(self) -> None:
super().__init__()
self.classes: dict[str, bool] = {}

def visit_class_def(self, o: ClassDef) -> None:
super().visit_class_def(o)
self.classes[o.fullname] = o.info.has_base("pydantic.main.BaseModel")


def run_mypy_visitor(arg_files: list[str], console: Console | None = None) -> dict[str, bool]:
console = console or Console()
files, opt = process_options(arg_files, stdout=sys.stdout, stderr=sys.stderr)

opt.export_types = True
opt.incremental = True
opt.fine_grained_incremental = True
opt.cache_fine_grained = True
opt.allow_redefinition = True
opt.local_partial_types = True

console.print("Running MyPy - this may take a while...")
result = build(files, opt, stdout=sys.stdout, stderr=sys.stderr)

visitor = MyPyVisitor()
classes: dict[str, bool] = {}

for file in files:
tree = result.graph[file.module].tree
if tree:
tree.accept(visitor=visitor)
classes.update(visitor.classes)
return classes


if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("files", nargs="+")
args = parser.parse_args()

run_mypy_visitor(args.files)
61 changes: 5 additions & 56 deletions bump_pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Set, Type, TypeVar, Union
from typing import Any, Callable, Dict, Iterable, List, Type, TypeVar, Union

import libcst as cst
from libcst.codemod import CodemodContext, ContextAwareTransformer
Expand All @@ -18,7 +18,7 @@

from bump_pydantic import __version__
from bump_pydantic.codemods import Rule, gather_codemods
from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor
from bump_pydantic.codemods.mypy_visitor import CONTEXT_KEY, run_mypy_visitor

app = Typer(
help="Convert Pydantic from V1 to V2 ♻️",
Expand Down Expand Up @@ -61,42 +61,8 @@ def main(
metadata_manager = FullRepoManager(".", files, providers=providers) # type: ignore[arg-type]
metadata_manager.resolve_cache()

scratch: dict[str, Any] = {}
with Progress(*Progress.get_default_columns(), transient=True) as progress:
task = progress.add_task(description="Looking for Pydantic Models...", total=len(files))

queue: List[str] = [files[0]]
visited: Set[str] = set()

while queue:
# Queue logic
filename = queue.pop()
visited.add(filename)
progress.advance(task)

# Visitor logic
code = Path(filename).read_text()
module = cst.parse_module(code)
module_and_package = calculate_module_and_package(str(package), filename)

context = CodemodContext(
metadata_manager=metadata_manager,
filename=filename,
full_module_name=module_and_package.name,
full_package_name=module_and_package.package,
scratch=scratch,
)
visitor = ClassDefVisitor(context=context)
visitor.transform_module(module)

# Queue logic
next_file = visitor.next_file(visited)
if next_file is not None:
queue.append(next_file)

missing_files = set(files) - visited
if not queue and missing_files:
queue.append(next(iter(missing_files)))
classes = run_mypy_visitor(files, console=console)
scratch: dict[str, Any] = {CONTEXT_KEY: classes}

start_time = time.time()

Expand All @@ -108,7 +74,7 @@ def main(
with Progress(*Progress.get_default_columns(), transient=True) as progress:
task = progress.add_task(description="Executing codemods...", total=len(files))
with multiprocessing.Pool() as pool, log_ctx_mgr as log_fp: # type: ignore[attr-defined]
for error_msg in pool.imap_unordered(partial_run_codemods, files):
for error_msg in pool.imap(partial_run_codemods, files):
progress.advance(task)
if error_msg is None:
continue
Expand Down Expand Up @@ -139,23 +105,6 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Union[T, Iterable[str]]:
return wrapper


@capture_exception
def visit_class_def(metadata_manager: FullRepoManager, package: Path, filename: str) -> Dict[str, Any]:
code = Path(filename).read_text()
module = cst.parse_module(code)
module_and_package = calculate_module_and_package(str(package), filename)

context = CodemodContext(
metadata_manager=metadata_manager,
filename=filename,
full_module_name=module_and_package.name,
full_package_name=module_and_package.package,
)
visitor = ClassDefVisitor(context=context)
visitor.transform_module(module)
return context.scratch


@capture_exception
def run_codemods(
codemods: List[Type[ContextAwareTransformer]],
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: PyPy",
"Framework :: Pydantic",
]
dependencies = ["typer>=0.7.0", "libcst", "rich", "typing_extensions"]
dependencies = ["typer>=0.7.0", "libcst", "rich", "typing_extensions", "mypy"]

[project.urls]
Documentation = "https://github.com/pydantic/bump-pydantic#readme"
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,7 @@ def test_command_line(tmp_path: Path, before: Folder, expected: Folder) -> None:
before.create_structure(root=Path(td))

result = runner.invoke(app, [before.name])
print(result.output)
assert result.exit_code == 0, result.output
# assert result.output.endswith("Refactored 4 files.\n")

Expand Down
7 changes: 4 additions & 3 deletions tests/unit/test_add_default_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from libcst.testing.utils import UnitTest

from bump_pydantic.codemods.add_default_none import AddDefaultNoneCommand
from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor
from bump_pydantic.codemods.mypy_visitor import CONTEXT_KEY, run_mypy_visitor


@pytest.mark.skip(reason="The file needs to exists for the test to pass.")
class TestClassDefVisitor(UnitTest):
def add_default_none(self, file_path: str, code: str) -> cst.Module:
mod = MetadataWrapper(
Expand All @@ -24,8 +25,8 @@ def add_default_none(self, file_path: str, code: str) -> cst.Module:
)
mod.resolve_many(AddDefaultNoneCommand.METADATA_DEPENDENCIES)
context = CodemodContext(wrapper=mod)
instance = ClassDefVisitor(context=context)
mod.visit(instance)
classes = run_mypy_visitor(arg_files=[file_path])
context.scratch.update({CONTEXT_KEY: classes})

instance = AddDefaultNoneCommand(context=context) # type: ignore[assignment]
return mod.visit(instance)
Expand Down
Loading

0 comments on commit 1ea7b36

Please sign in to comment.