Skip to content

Commit

Permalink
♻️ Replace MypyVisitor by ClassDefVisitor (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Jul 17, 2023
1 parent 67c8cc3 commit 082cc45
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 68 deletions.
7 changes: 2 additions & 5 deletions bump_pydantic/codemods/add_default_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.metadata import FullyQualifiedNameProvider, QualifiedName

from bump_pydantic.codemods.mypy_visitor import CONTEXT_KEY
from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor


class AddDefaultNoneCommand(VisitorBasedCodemodCommand):
Expand Down Expand Up @@ -49,7 +49,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None:
return None

fqn: QualifiedName = next(iter(fqn_set)) # type: ignore
if self.context.scratch[CONTEXT_KEY].get(fqn.name, False):
if fqn.name in self.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY]:
self.inside_base_model = True

def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
Expand Down Expand Up @@ -138,9 +138,6 @@ class Bar(Foo):
wrapper = mrg.get_metadata_wrapper_for_path(module)
context = CodemodContext(wrapper=wrapper)

# classes = run_mypy_visitor(context=context)
# mod = wrapper.visit(command)

command = AddDefaultNoneCommand(context=context) # type: ignore[assignment]
mod = wrapper.visit(command)
print(mod.code)
147 changes: 147 additions & 0 deletions bump_pydantic/codemods/class_def_visitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""
The visitor keeps three objects in the codemod context's scratch:
1. `base_model_cls` (Set[str]): Set of classes that are BaseModel based.
2. `no_base_model_cls` (Set[str]): Set of classes that are known not to be BaseModel based.
3. `cls` (Dict[str, Set[str]]): A dictionary mapping each base class that is not resolved yet to the set of
classes that inherit from it.
`base_model_cls` and `no_base_model_cls` accumulate on each iteration.
`cls` also accumulates on each iteration, but it's also partially solved:
1. Check if the module visited is a prefix of any `cls.keys()`.
1.1. If it is, and if any `base_model_cls` is found, remove from `cls`, and add to `base_model_cls`.
1.2. If it's not, it continues on the `cls`.
"""
from __future__ import annotations

from collections import defaultdict
from typing import Set, cast

import libcst as cst
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.metadata import FullyQualifiedNameProvider, QualifiedName


class ClassDefVisitor(VisitorBasedCodemodCommand):
    """Classify every visited class as BaseModel-based or not.

    Results are accumulated in ``context.scratch`` so they can be shared by
    every visitor that is created with the same scratch dictionary:

    * ``BASE_MODEL_CONTEXT_KEY``: set of fully qualified names of classes that
      derive from a known BaseModel class.
    * ``NO_BASE_MODEL_CONTEXT_KEY``: set of fully qualified names of classes
      that have no bases, or derive from a class already in this set.
    * ``CLS_CONTEXT_KEY``: ``defaultdict(set)`` mapping a base class that is
      not classified yet to the classes that inherit from it.
    """

    METADATA_DEPENDENCIES = {FullyQualifiedNameProvider}

    BASE_MODEL_CONTEXT_KEY = "base_model_cls"
    NO_BASE_MODEL_CONTEXT_KEY = "no_base_model_cls"
    CLS_CONTEXT_KEY = "cls"

    def __init__(self, context: CodemodContext) -> None:
        """Create the visitor and make sure the three scratch entries exist.

        ``setdefault`` is used so that entries already present in the scratch
        (e.g. filled by a previous visitor) are kept rather than reset.
        """
        super().__init__(context)
        self.module_fqn: None | QualifiedName = None

        scratch = self.context.scratch
        scratch.setdefault(
            self.BASE_MODEL_CONTEXT_KEY,
            {"pydantic.BaseModel", "pydantic.main.BaseModel"},
        )
        scratch.setdefault(self.NO_BASE_MODEL_CONTEXT_KEY, set())
        scratch.setdefault(self.CLS_CONTEXT_KEY, defaultdict(set))

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Classify ``node`` and every class that was waiting for its verdict.

        Does nothing when no fully qualified name is available for ``node``.
        """
        fqn_set = self.get_metadata(FullyQualifiedNameProvider, node)
        if not fqn_set:
            return None

        fqn: QualifiedName = next(iter(fqn_set))  # type: ignore

        scratch = self.context.scratch
        base_models = scratch[self.BASE_MODEL_CONTEXT_KEY]
        non_base_models = scratch[self.NO_BASE_MODEL_CONTEXT_KEY]
        pending = scratch[self.CLS_CONTEXT_KEY]

        if not node.bases:
            non_base_models.add(fqn.name)

        # Names of bases that are in neither set yet; collected across *all*
        # bases so none of them is forgotten for multi-base classes.
        unresolved_bases: Set[str] = set()
        for arg in node.bases:
            base_fqn_set = self.get_metadata(FullyQualifiedNameProvider, arg.value)
            base_fqn_set = base_fqn_set or set()

            for base_fqn in cast(Set[QualifiedName], base_fqn_set):
                if base_fqn.name in base_models:
                    base_models.add(fqn.name)
                elif base_fqn.name in non_base_models:
                    non_base_models.add(fqn.name)
                else:
                    unresolved_bases.add(base_fqn.name)

        # In case we have the following scenario:
        #   class Z(A): ...
        #   class A(B): ...
        #   class B(BaseModel): ...
        #   class D(C): ...
        #   class C: ...
        # As soon as `B` is known to be a `BaseModel`, `A` *and* `Z` are too;
        # as soon as `C` is known NOT to be one, neither is `D`.
        if fqn.name in base_models:
            self._resolve_pending(fqn.name, base_models)
        if fqn.name in non_base_models:
            self._resolve_pending(fqn.name, non_base_models)

        # The class could not be classified from any of its bases: remember it
        # under each unknown base so it is resolved when that base is visited.
        if fqn.name not in base_models and fqn.name not in non_base_models:
            for base_name in unresolved_bases:
                pending[base_name].add(fqn.name)

    def _resolve_pending(self, class_name: str, verdict: Set[str]) -> None:
        """Add every class waiting on ``class_name`` to ``verdict``, transitively.

        Each processed name is popped from the pending mapping, so the loop
        terminates even if the recorded relationships were to form a cycle.
        """
        pending = self.context.scratch[self.CLS_CONTEXT_KEY]
        worklist = [class_name]
        while worklist:
            # `pop(..., ())` instead of indexing: the mapping is a defaultdict
            # and indexing would insert an empty entry for every lookup.
            for subclass_name in pending.pop(worklist.pop(), ()):
                verdict.add(subclass_name)
                worklist.append(subclass_name)

    # TODO: Implement this if needed...
    def next_file(self, visited: set[str]) -> str | None:
        """Return the next file to visit; currently always ``None``."""
        return None


if __name__ == "__main__":
    import os
    import textwrap
    from pathlib import Path
    from tempfile import TemporaryDirectory

    from libcst.metadata import FullRepoManager
    from rich.pretty import pprint

    # Manual demo: write a small sample package into a temporary directory
    # under the current working directory, run the visitor on it and print
    # the three scratch entries it filled.
    with TemporaryDirectory(dir=os.getcwd()) as tmpdir:
        package_path = Path(tmpdir) / "package"
        package_path.mkdir()

        sample_file = package_path / "a.py"
        sample_file.write_text(
            textwrap.dedent(
                """
                from pydantic import BaseModel
                class Foo(BaseModel):
                    a: str
                class Bar(Foo):
                    b: str
                class Potato:
                    ...
                class Spam(Potato):
                    ...
                foo = Foo(a="text")
                foo.dict()
                """
            )
        )

        # The repo manager expects the module path relative to the repo root.
        module = str(sample_file.relative_to(tmpdir))
        repo_manager = FullRepoManager(tmpdir, {module}, providers={FullyQualifiedNameProvider})
        wrapper = repo_manager.get_metadata_wrapper_for_path(module)
        context = CodemodContext(wrapper=wrapper)
        wrapper.visit(ClassDefVisitor(context=context))

        for scratch_key in (
            ClassDefVisitor.BASE_MODEL_CONTEXT_KEY,
            ClassDefVisitor.NO_BASE_MODEL_CONTEXT_KEY,
            ClassDefVisitor.CLS_CONTEXT_KEY,
        ):
            pprint(context.scratch[scratch_key])
52 changes: 0 additions & 52 deletions bump_pydantic/codemods/mypy_visitor.py

This file was deleted.

43 changes: 37 additions & 6 deletions bump_pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
import traceback
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
from typing import Any, Dict, Iterable, List, Set, Tuple, Type, TypeVar, Union

import libcst as cst
from libcst.codemod import CodemodContext, ContextAwareTransformer
Expand All @@ -18,7 +18,7 @@

from bump_pydantic import __version__
from bump_pydantic.codemods import Rule, gather_codemods
from bump_pydantic.codemods.mypy_visitor import CONTEXT_KEY, run_mypy_visitor
from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor

app = Typer(
help="Convert Pydantic from V1 to V2 ♻️",
Expand Down Expand Up @@ -69,10 +69,41 @@ def main(
metadata_manager = FullRepoManager(".", files, providers=providers) # type: ignore[arg-type]
metadata_manager.resolve_cache()

console.log("Running mypy to get type information. This may take a while...")
classes = run_mypy_visitor(files)
scratch: dict[str, Any] = {CONTEXT_KEY: classes}
console.log("Finished mypy.")
scratch: dict[str, Any] = {}
with Progress(*Progress.get_default_columns(), transient=True) as progress:
task = progress.add_task(description="Looking for Pydantic Models...", total=len(files))
queue: List[str] = [files[0]]
visited: Set[str] = set()

while queue:
# Queue logic
filename = queue.pop()
visited.add(filename)
progress.advance(task)

# Visitor logic
code = Path(filename).read_text()
module = cst.parse_module(code)
module_and_package = calculate_module_and_package(str(package), filename)

context = CodemodContext(
metadata_manager=metadata_manager,
filename=filename,
full_module_name=module_and_package.name,
full_package_name=module_and_package.package,
scratch=scratch,
)
visitor = ClassDefVisitor(context=context)
visitor.transform_module(module)

# Queue logic
next_file = visitor.next_file(visited)
if next_file is not None:
queue.append(next_file)

missing_files = set(files) - visited
if not queue and missing_files:
queue.append(next(iter(missing_files)))

start_time = time.time()

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: PyPy",
"Framework :: Pydantic",
]
dependencies = ["typer>=0.7.0", "libcst", "rich", "typing_extensions", "mypy"]
dependencies = ["typer>=0.7.0", "libcst", "rich", "typing_extensions"]

[project.urls]
Documentation = "https://github.com/pydantic/bump-pydantic#readme"
Expand Down
7 changes: 3 additions & 4 deletions tests/unit/test_add_default_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
from libcst.testing.utils import UnitTest

from bump_pydantic.codemods.add_default_none import AddDefaultNoneCommand
from bump_pydantic.codemods.mypy_visitor import CONTEXT_KEY, run_mypy_visitor
from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor


@pytest.mark.skip(reason="The file needs to exists for the test to pass.")
class TestClassDefVisitor(UnitTest):
def add_default_none(self, file_path: str, code: str) -> cst.Module:
mod = MetadataWrapper(
Expand All @@ -25,8 +24,8 @@ def add_default_none(self, file_path: str, code: str) -> cst.Module:
)
mod.resolve_many(AddDefaultNoneCommand.METADATA_DEPENDENCIES)
context = CodemodContext(wrapper=mod)
classes = run_mypy_visitor(arg_files=[file_path])
context.scratch.update({CONTEXT_KEY: classes})
instance = ClassDefVisitor(context=context)
mod.visit(instance)

instance = AddDefaultNoneCommand(context=context) # type: ignore[assignment]
return mod.visit(instance)
Expand Down

0 comments on commit 082cc45

Please sign in to comment.