-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ Support GenericModel to BaseModel replacement (#12)
- Loading branch information
Showing
8 changed files
with
215 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from __future__ import annotations | ||
|
||
import libcst as cst | ||
import libcst.matchers as m | ||
from libcst.codemod import VisitorBasedCodemodCommand | ||
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor | ||
|
||
GENERIC_MODEL_ARG = m.Arg(value=m.Name("GenericModel")) | m.Arg( | ||
value=m.Attribute(value=m.Name("generics"), attr=m.Name("GenericModel")) | ||
) | ||
|
||
|
||
class ReplaceGenericModelCommand(VisitorBasedCodemodCommand): | ||
@m.leave(m.ClassDef(bases=[m.ZeroOrMore(), GENERIC_MODEL_ARG, m.ZeroOrMore()])) | ||
def leave_generic_model(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: | ||
RemoveImportsVisitor.remove_unused_import(context=self.context, module="pydantic.generics", obj="GenericModel") | ||
RemoveImportsVisitor.remove_unused_import(context=self.context, module="pydantic", obj="generics") | ||
AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="BaseModel") | ||
return updated_node.with_changes( | ||
bases=[ | ||
base if not m.matches(base, GENERIC_MODEL_ARG) else cst.Arg(value=cst.Name("BaseModel")) | ||
for base in updated_node.bases | ||
] | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
import textwrap | ||
|
||
from rich.console import Console | ||
|
||
console = Console() | ||
|
||
source = textwrap.dedent( | ||
""" | ||
from typing import Generic, TypeVar | ||
from pydantic.generics import GenericModel | ||
T = TypeVar("T") | ||
class Potato(GenericModel, Generic[T]): | ||
... | ||
""" | ||
) | ||
console.print(source) | ||
# console.print("=" * 80) | ||
|
||
# mod = cst.parse_module(source) | ||
# context = CodemodContext(filename="main.py") | ||
|
||
# wrapper = cst.MetadataWrapper(mod) | ||
# command = ReplaceGenericModelCommand(context=context) | ||
# mod = wrapper.visit(command) | ||
|
||
# wrapper = cst.MetadataWrapper(mod) | ||
# command = RemoveImportsVisitor(context=context) # type: ignore[assignment] | ||
# mod = wrapper.visit(command) | ||
|
||
# wrapper = cst.MetadataWrapper(mod) | ||
# command = AddImportsVisitor(context=context) # type: ignore[assignment] | ||
# mod = wrapper.visit(command) | ||
# console.print(mod.code) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from typing import Generic, TypeVar | ||
|
||
from pydantic.generics import GenericModel | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
class User(GenericModel, Generic[T]): | ||
name: str |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
from libcst.codemod import CodemodTest | ||
|
||
from bump_pydantic.codemods.replace_generic_model import ReplaceGenericModelCommand | ||
|
||
|
||
class TestReplaceGenericModelCommand(CodemodTest): | ||
TRANSFORM = ReplaceGenericModelCommand | ||
|
||
def test_noop(self) -> None: | ||
code = """ | ||
from typing import Generic, TypeVar | ||
T = TypeVar("T") | ||
class Potato(Generic[T]): | ||
... | ||
""" | ||
self.assertCodemod(code, code) | ||
|
||
def test_generic_model(self) -> None: | ||
before = """ | ||
from typing import TypeVar | ||
from pydantic.generics import GenericModel | ||
T = TypeVar("T") | ||
class Potato(GenericModel, Generic[T]): | ||
... | ||
""" | ||
after = """ | ||
from typing import TypeVar | ||
from pydantic import BaseModel | ||
T = TypeVar("T") | ||
class Potato(BaseModel, Generic[T]): | ||
... | ||
""" | ||
self.assertCodemod(before, after) | ||
|
||
def test_generic_model_multiple_bases(self) -> None: | ||
before = """ | ||
from typing import TypeVar | ||
from pydantic.generics import GenericModel | ||
T = TypeVar("T") | ||
class Potato(GenericModel, Generic[T], object): | ||
... | ||
""" | ||
after = """ | ||
from typing import TypeVar | ||
from pydantic import BaseModel | ||
T = TypeVar("T") | ||
class Potato(BaseModel, Generic[T], object): | ||
... | ||
""" | ||
self.assertCodemod(before, after) | ||
|
||
def test_generic_model_second_base(self) -> None: | ||
before = """ | ||
from typing import TypeVar | ||
from pydantic.generics import GenericModel | ||
T = TypeVar("T") | ||
class Potato(object, GenericModel, Generic[T]): | ||
... | ||
""" | ||
after = """ | ||
from typing import TypeVar | ||
from pydantic import BaseModel | ||
T = TypeVar("T") | ||
class Potato(object, BaseModel, Generic[T]): | ||
... | ||
""" | ||
self.assertCodemod(before, after) | ||
|
||
def test_generic_model_from_pydantic_import_generics(self) -> None: | ||
before = """ | ||
from typing import TypeVar | ||
from pydantic import generics | ||
T = TypeVar("T") | ||
class Potato(generics.GenericModel, Generic[T]): | ||
... | ||
""" | ||
after = """ | ||
from typing import TypeVar | ||
from pydantic import BaseModel | ||
T = TypeVar("T") | ||
class Potato(BaseModel, Generic[T]): | ||
... | ||
""" | ||
self.assertCodemod(before, after) |