Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/mark custom types #116

Merged
merged 13 commits into from
Sep 18, 2023
43 changes: 41 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Bump Pydantic is a tool to help you migrate your code from Pydantic V1 to V2.
- [BP006: Replace `__root__` by `RootModel`](#bp006-replace-__root__-by-rootmodel)
- [BP007: Replace decorators](#bp007-replace-decorators)
- [BP008: Replace `con*` functions by `Annotated` versions](#bp008-replace-con-functions-by-annotated-versions)
- [BP009: Mark pydantic "protocol" functions in custom types with proper TODOs](bp009-mark-pydantic-protocol-functions-in-custom-types-with-proper-todos)

- [License](#license)

---
Expand Down Expand Up @@ -301,7 +303,44 @@ class User(BaseModel):
name: Annotated[str, StringConstraints(min_length=1)]
```

<!-- ### BP009: Replace `pydantic.parse_obj_as` by `pydantic.TypeAdapter`
### BP009: Mark Pydantic "protocol" functions in custom types with proper TODOs

- ✅ Mark `__get_validators__` as to be replaced by `__get_pydantic_core_schema__`.
- ✅ Mark `__modify_schema__` as to be replaced by `__get_pydantic_json_schema__`.

The following code will be transformed:

```py
class SomeThing:
@classmethod
def __get_validators__(cls):
yield from []

@classmethod
def __modify_schema__(cls, field_schema, field):
if field:
field_schema['example'] = "Weird example"
```

Into:

```py
class SomeThing:
@classmethod
# TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually.
# Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.
def __get_validators__(cls):
yield from []

@classmethod
# TODO[pydantic]: We couldn't refactor `__modify_schema__`, please create the `__get_pydantic_json_schema__` manually.
# Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.
def __modify_schema__(cls, field_schema, field):
if field:
field_schema['example'] = "Weird example"
```

<!-- ### BP010: Replace `pydantic.parse_obj_as` by `pydantic.TypeAdapter`

- ✅ Replace `pydantic.parse_obj_as(T, obj)` to `pydantic.TypeAdapter(T).validate_python(obj)`.

Expand Down Expand Up @@ -344,7 +383,7 @@ class Users(BaseModel):
users = TypeAdapter(Users).validate_python({'users': [{'name': 'John'}]})
``` -->

<!-- ### BP010: Replace `PyObject` by `ImportString`
<!-- ### BP011: Replace `PyObject` by `ImportString`

- ✅ Replace `PyObject` by `ImportString`.

Expand Down
6 changes: 6 additions & 0 deletions bump_pydantic/codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from bump_pydantic.codemods.add_default_none import AddDefaultNoneCommand
from bump_pydantic.codemods.con_func import ConFuncCallCommand
from bump_pydantic.codemods.custom_types import CustomTypeCodemod
from bump_pydantic.codemods.field import FieldCodemod
from bump_pydantic.codemods.replace_config import ReplaceConfigCodemod
from bump_pydantic.codemods.replace_generic_model import ReplaceGenericModelCommand
Expand All @@ -31,6 +32,8 @@ class Rule(str, Enum):
"""Replace `@validator` with `@field_validator`."""
BP008 = "BP008"
"""Replace `con*` functions by `Annotated` versions."""
BP009 = "BP009"
"""Mark Pydantic "protocol" functions in custom types with proper TODOs."""


def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]:
Expand Down Expand Up @@ -61,6 +64,9 @@ def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]
if Rule.BP007 not in disabled:
codemods.append(ValidatorCodemod)

if Rule.BP009 not in disabled:
codemods.append(CustomTypeCodemod)

# Those codemods need to be the last ones.
codemods.extend([RemoveImportsVisitor, AddImportsVisitor])
return codemods
72 changes: 72 additions & 0 deletions bump_pydantic/codemods/custom_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import libcst as cst
from libcst import matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor

PREFIX_COMMENT = "# TODO[pydantic]: "
REFACTOR_COMMENT = f"{PREFIX_COMMENT}We couldn't refactor `{{old_name}}`, please create the `{{new_name}}` manually."
GET_VALIDATORS_COMMENT = REFACTOR_COMMENT.format(old_name="__get_validators__", new_name="__get_pydantic_core_schema__")
MODIFY_SCHEMA_COMMENT = REFACTOR_COMMENT.format(old_name="__modify_schema__", new_name="__get_pydantic_json_schema__")
COMMENT_BY_FUNC_NAME = {"__get_validators__": GET_VALIDATORS_COMMENT, "__modify_schema__": MODIFY_SCHEMA_COMMENT}
CHECK_LINK_COMMENT = "# Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information."

GET_VALIDATORS_FUNCTION = m.FunctionDef(name=m.Name("__get_validators__"))
MODIFY_SCHEMA_FUNCTION = m.FunctionDef(name=m.Name("__modify_schema__"))


class CustomTypeCodemod(VisitorBasedCodemodCommand):
@m.leave(MODIFY_SCHEMA_FUNCTION | GET_VALIDATORS_FUNCTION)
def leave_modify_schema_func(
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> cst.FunctionDef:
for line in [*updated_node.leading_lines, *updated_node.lines_after_decorators]:
if m.matches(line, m.EmptyLine(comment=m.Comment(value=CHECK_LINK_COMMENT))):
return updated_node

comment = COMMENT_BY_FUNC_NAME[updated_node.name.value]
return updated_node.with_changes(
lines_after_decorators=[
*updated_node.lines_after_decorators,
cst.EmptyLine(comment=cst.Comment(value=(comment))),
cst.EmptyLine(comment=cst.Comment(value=(CHECK_LINK_COMMENT))),
]
)


if __name__ == "__main__":
import textwrap

from rich.console import Console

console = Console()

source = textwrap.dedent(
"""
class SomeThing:
@classmethod
def __get_validators__(cls):
yield from []
return

@classmethod
def __modify_schema__(
cls, field_schema: Dict[str, Any], field: Optional[ModelField]
):
if field:
field_schema['example'] = "Weird example"
"""
)
console.print(source)
console.print("=" * 80)

mod = cst.parse_module(source)
context = CodemodContext(filename="main.py")
wrapper = cst.MetadataWrapper(mod)
command = CustomTypeCodemod(context=context)
# console.print(mod)

mod = wrapper.visit(command)
wrapper = cst.MetadataWrapper(mod)
command = AddImportsVisitor(context=context) # type: ignore[assignment]
mod = wrapper.visit(command)
console.print(mod.code)
2 changes: 2 additions & 0 deletions tests/integration/cases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .base_settings import cases as base_settings_cases
from .con_func import cases as con_func_cases
from .config_to_model import cases as config_to_model_cases
from .custom_types import cases as custom_types_cases
from .field import cases as generic_model_cases
from .folder_inside_folder import cases as folder_inside_folder_cases
from .is_base_model import cases as is_base_model_cases
Expand All @@ -28,6 +29,7 @@
*folder_inside_folder_cases,
*unicode_cases,
*con_func_cases,
*custom_types_cases,
]
before = Folder("project", *[case.source for case in cases])
expected = Folder("project", *[case.expected for case in cases])
59 changes: 59 additions & 0 deletions tests/integration/cases/custom_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from ..case import Case
from ..file import File

cases = [
Case(
name="Mark __get_validators__",
source=File(
"mark_get_validators.py",
content=[
"class SomeThing:",
" @classmethod",
" def __get_validators__(cls):",
" yield from []",
" return",
],
),
expected=File(
"mark_get_validators.py",
content=[
"class SomeThing:",
" @classmethod",
" # TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually.", # noqa: E501
" # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.",
" def __get_validators__(cls):",
" yield from []",
" return",
],
),
),
Case(
name="Mark __modify_schema__",
source=File(
"mark_modify_schema.py",
content=[
"class SomeThing:",
" @classmethod",
" def __modify_schema__(",
" cls, field_schema: Dict[str, Any], field: Optional[ModelField]",
" ):",
" if field:",
" field_schema['example'] = \"Weird example\"",
],
),
expected=File(
"mark_modify_schema.py",
content=[
"class SomeThing:",
" @classmethod",
" # TODO[pydantic]: We couldn't refactor `__modify_schema__`, please create the `__get_pydantic_json_schema__` manually.", # noqa: E501
" # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.",
" def __modify_schema__(",
" cls, field_schema: Dict[str, Any], field: Optional[ModelField]",
" ):",
" if field:",
" field_schema['example'] = \"Weird example\"",
],
),
),
]
76 changes: 76 additions & 0 deletions tests/unit/test_custom_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from libcst.codemod import CodemodTest

from bump_pydantic.codemods.custom_types import CustomTypeCodemod


class TestArbitraryClassCommand(CodemodTest):
TRANSFORM = CustomTypeCodemod

maxDiff = None

def test_mark_get_validators(self) -> None:
before = """
class SomeThing:
@classmethod
def __get_validators__(cls):
yield from []
return
"""
after = """
class SomeThing:
@classmethod
# TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually.
# Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.
def __get_validators__(cls):
yield from []
return
""" # noqa: E501
self.assertCodemod(before, after)

def test_mark_modify_schema(self) -> None:
before = """
class SomeThing:
@classmethod
def __modify_schema__(
cls, field_schema: Dict[str, Any], field: Optional[ModelField]
):
if field:
field_schema['example'] = "Weird example"
"""
after = """
class SomeThing:
@classmethod
# TODO[pydantic]: We couldn't refactor `__modify_schema__`, please create the `__get_pydantic_json_schema__` manually.
# Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.
def __modify_schema__(
cls, field_schema: Dict[str, Any], field: Optional[ModelField]
):
if field:
field_schema['example'] = "Weird example"
""" # noqa: E501
self.assertCodemod(before, after)

def test_already_commented(self) -> None:
before = """
class SomeThing:
@classmethod
# TODO[pydantic]: We couldn't refactor `__modify_schema__`, please create the `__get_pydantic_json_schema__` manually.
# Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.
def __modify_schema__(
cls, field_schema: Dict[str, Any], field: Optional[ModelField]
):
if field:
field_schema['example'] = "Weird example"
""" # noqa: E501
after = """
class SomeThing:
@classmethod
# TODO[pydantic]: We couldn't refactor `__modify_schema__`, please create the `__get_pydantic_json_schema__` manually.
# Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.
def __modify_schema__(
cls, field_schema: Dict[str, Any], field: Optional[ModelField]
):
if field:
field_schema['example'] = "Weird example"
""" # noqa: E501
self.assertCodemod(before, after)