Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/mark custom types #116

Merged
merged 13 commits into from
Sep 18, 2023
51 changes: 49 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Bump Pydantic is a tool to help you migrate your code from Pydantic V1 to V2.
- [BP006: Replace `__root__` by `RootModel`](#bp006-replace-__root__-by-rootmodel)
- [BP007: Replace decorators](#bp007-replace-decorators)
- [BP008: Replace `con*` functions by `Annotated` versions](#bp008-replace-con-functions-by-annotated-versions)
- [BP009: Mark pydantic "protocol" functions in custom types with proper TODOs](bp009-mark-pydantic-protocol-functions-in-custom-types-with-proper-todos)

- [License](#license)

---
Expand Down Expand Up @@ -301,7 +303,50 @@ class User(BaseModel):
name: Annotated[str, StringConstraints(min_length=1)]
```

<!-- ### BP009: Replace `pydantic.parse_obj_as` by `pydantic.TypeAdapter`
### BP009: Mark pydantic "protocol" functions in custom types with proper TODOs
Kludex marked this conversation as resolved.
Show resolved Hide resolved

- ✅ Mark `__get_validators__` as to be replaced by `__get_pydantic_core_schema__`.
- ✅ Mark `__modify_schema__` as to be replaced by `__get_pydantic_json_schema__`.

The following code will be transformed:

```py
class SomeThing:
@classmethod
def __get_validators__(cls):
yield from []
return
Kludex marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def __modify_schema__(
cls, field_schema: Dict[str, Any], field: Optional[ModelField]
):
Kludex marked this conversation as resolved.
Show resolved Hide resolved
if field:
field_schema['example'] = "Weird example"
```

Into:

```py
class SomeThing:
@classmethod
# TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually.
# Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.
def __get_validators__(cls):
yield from []
return
Kludex marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
# TODO[pydantic]: We couldn't refactor `__modify_schema__`, please create the `__get_pydantic_json_schema__` manually.
# Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.
def __modify_schema__(
cls, field_schema: Dict[str, Any], field: Optional[ModelField]
):
Kludex marked this conversation as resolved.
Show resolved Hide resolved
if field:
field_schema['example'] = "Weird example"
```

<!-- ### BP010: Replace `pydantic.parse_obj_as` by `pydantic.TypeAdapter`

- ✅ Replace `pydantic.parse_obj_as(T, obj)` to `pydantic.TypeAdapter(T).validate_python(obj)`.

Expand Down Expand Up @@ -344,7 +389,7 @@ class Users(BaseModel):
users = TypeAdapter(Users).validate_python({'users': [{'name': 'John'}]})
``` -->

<!-- ### BP010: Replace `PyObject` by `ImportString`
<!-- ### BP011: Replace `PyObject` by `ImportString`

- ✅ Replace `PyObject` by `ImportString`.

Expand All @@ -368,6 +413,8 @@ class User(BaseModel):
name: ImportString
``` -->



Kludex marked this conversation as resolved.
Show resolved Hide resolved
---

## License
Expand Down
6 changes: 6 additions & 0 deletions bump_pydantic/codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from bump_pydantic.codemods.add_default_none import AddDefaultNoneCommand
from bump_pydantic.codemods.con_func import ConFuncCallCommand
from bump_pydantic.codemods.custom_types import CustomTypeCodemod
from bump_pydantic.codemods.field import FieldCodemod
from bump_pydantic.codemods.replace_config import ReplaceConfigCodemod
from bump_pydantic.codemods.replace_generic_model import ReplaceGenericModelCommand
Expand All @@ -31,6 +32,8 @@ class Rule(str, Enum):
"""Replace `@validator` with `@field_validator`."""
BP008 = "BP008"
"""Replace `con*` functions by `Annotated` versions."""
BP009 = "BP009"
"""Mark pydantic 'protocol' functions in custom types with proper TODOs"""
Kludex marked this conversation as resolved.
Show resolved Hide resolved


def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]:
Expand Down Expand Up @@ -61,6 +64,9 @@ def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]
if Rule.BP007 not in disabled:
codemods.append(ValidatorCodemod)

if Rule.BP009 not in disabled:
codemods.append(CustomTypeCodemod)

# Those codemods need to be the last ones.
codemods.extend([RemoveImportsVisitor, AddImportsVisitor])
return codemods
101 changes: 101 additions & 0 deletions bump_pydantic/codemods/custom_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from typing import List

import libcst as cst
from libcst import matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor

PREFIX_COMMENT = "# TODO[pydantic]: "
REFACTOR_COMMENT = f"{PREFIX_COMMENT}We couldn't refactor `{{old_name}}`, please create the `{{new_name}}` manually."
GET_VALIDATORS_COMMENT = REFACTOR_COMMENT.format(old_name="__get_validators__", new_name="__get_pydantic_core_schema__")
MODIFY_SCHEMA_COMMENT = REFACTOR_COMMENT.format(old_name="__modify_schema__", new_name="__get_pydantic_json_schema__")
CHECK_LINK_COMMENT = "# Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information."

GET_VALIDATORS_FUNCTION = m.FunctionDef(name=m.Name("__get_validators__"))
MODIFY_SCHEMA_FUNCTION = m.FunctionDef(name=m.Name("__modify_schema__"))


class CustomTypeCodemod(VisitorBasedCodemodCommand):
def __init__(self, context: CodemodContext) -> None:
super().__init__(context)

self._already_modified = False
self._should_add_comment = False
self._has_comment = False
self._args: List[cst.Arg] = []

@m.visit(GET_VALIDATORS_FUNCTION)
def visit_get_validators_func(self, node: cst.FunctionDef) -> None:
for line in node.leading_lines:
if m.matches(line, m.EmptyLine(comment=m.Comment(value=CHECK_LINK_COMMENT))):
self._has_comment = True

@m.leave(MODIFY_SCHEMA_FUNCTION)
def leave_modify_schema_func(
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> cst.FunctionDef:
if self._has_comment:
self._has_comment = False
return updated_node

return self._function_with_leading_comment_after_optional_decorators(updated_node, MODIFY_SCHEMA_COMMENT)

@m.leave(GET_VALIDATORS_FUNCTION)
def leave_get_validators_func(
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> cst.FunctionDef:
if self._has_comment:
self._has_comment = False
return updated_node

return self._function_with_leading_comment_after_optional_decorators(updated_node, GET_VALIDATORS_COMMENT)

def _function_with_leading_comment_after_optional_decorators(
self, node: cst.FunctionDef, comment: str
) -> cst.FunctionDef:
return node.with_changes(
lines_after_decorators=[
*node.lines_after_decorators,
cst.EmptyLine(comment=cst.Comment(value=(comment))),
cst.EmptyLine(comment=cst.Comment(value=(CHECK_LINK_COMMENT))),
]
)


if __name__ == "__main__":
import textwrap

from rich.console import Console

console = Console()

source = textwrap.dedent(
"""
class SomeThing:
@classmethod
def __get_validators__(cls):
yield from []
return

@classmethod
def __modify_schema__(
cls, field_schema: Dict[str, Any], field: Optional[ModelField]
):
if field:
field_schema['example'] = "Weird example"
"""
)
console.print(source)
console.print("=" * 80)

mod = cst.parse_module(source)
context = CodemodContext(filename="main.py")
wrapper = cst.MetadataWrapper(mod)
command = CustomTypeCodemod(context=context)
# console.print(mod)

mod = wrapper.visit(command)
wrapper = cst.MetadataWrapper(mod)
command = AddImportsVisitor(context=context) # type: ignore[assignment]
mod = wrapper.visit(command)
console.print(mod.code)
2 changes: 2 additions & 0 deletions tests/integration/cases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .base_settings import cases as base_settings_cases
from .con_func import cases as con_func_cases
from .config_to_model import cases as config_to_model_cases
from .custom_types import cases as custom_types_cases
from .field import cases as generic_model_cases
from .folder_inside_folder import cases as folder_inside_folder_cases
from .is_base_model import cases as is_base_model_cases
Expand All @@ -28,6 +29,7 @@
*folder_inside_folder_cases,
*unicode_cases,
*con_func_cases,
*custom_types_cases,
]
before = Folder("project", *[case.source for case in cases])
expected = Folder("project", *[case.expected for case in cases])
59 changes: 59 additions & 0 deletions tests/integration/cases/custom_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from ..case import Case
from ..file import File

cases = [
Case(
name="Mark __get_validators__",
source=File(
"mark_get_validators.py",
content=[
"class SomeThing:",
" @classmethod",
" def __get_validators__(cls):",
" yield from []",
" return",
],
),
expected=File(
"mark_get_validators.py",
content=[
"class SomeThing:",
" @classmethod",
" # TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually.", # noqa: E501
" # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.",
" def __get_validators__(cls):",
" yield from []",
" return",
],
),
),
Case(
name="Mark __modify_schema__",
source=File(
"mark_modify_schema.py",
content=[
"class SomeThing:",
" @classmethod",
" def __modify_schema__(",
" cls, field_schema: Dict[str, Any], field: Optional[ModelField]",
" ):",
" if field:",
" field_schema['example'] = \"Weird example\"",
],
),
expected=File(
"mark_modify_schema.py",
content=[
"class SomeThing:",
" @classmethod",
" # TODO[pydantic]: We couldn't refactor `__modify_schema__`, please create the `__get_pydantic_json_schema__` manually.", # noqa: E501
" # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.",
" def __modify_schema__(",
" cls, field_schema: Dict[str, Any], field: Optional[ModelField]",
" ):",
" if field:",
" field_schema['example'] = \"Weird example\"",
],
),
),
]
51 changes: 51 additions & 0 deletions tests/unit/test_custom_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from libcst.codemod import CodemodTest

from bump_pydantic.codemods.custom_types import CustomTypeCodemod


class TestArbitraryClassCommand(CodemodTest):
TRANSFORM = CustomTypeCodemod

maxDiff = None

def test_mark_get_validators(self) -> None:
before = """
class SomeThing:
@classmethod
def __get_validators__(cls):
yield from []
return
"""
after = """
class SomeThing:
@classmethod
# TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually.
# Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.
def __get_validators__(cls):
yield from []
return
""" # noqa: E501
self.assertCodemod(before, after)

def test_mark_modify_schema(self) -> None:
before = """
class SomeThing:
@classmethod
def __modify_schema__(
cls, field_schema: Dict[str, Any], field: Optional[ModelField]
):
if field:
field_schema['example'] = "Weird example"
"""
after = """
class SomeThing:
@classmethod
# TODO[pydantic]: We couldn't refactor `__modify_schema__`, please create the `__get_pydantic_json_schema__` manually.
# Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.
def __modify_schema__(
cls, field_schema: Dict[str, Any], field: Optional[ModelField]
):
if field:
field_schema['example'] = "Weird example"
""" # noqa: E501
self.assertCodemod(before, after)