diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..87dade7 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,26 @@ +name: CI + +on: + push: + branches: + - main + pull_request: {} + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: set up python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + + - name: Install hatch + run: pip install hatch + + - uses: pre-commit/action@v3.0.0 + with: + extra_args: --all-files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..1483e87 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: check-yaml + args: ['--unsafe'] + - id: check-toml + - id: end-of-file-fixer + - id: trailing-whitespace + +- repo: local + hooks: + - id: lint + name: Lint + entry: hatch run lint + types: [python] + language: system + pass_filenames: false diff --git a/README.md b/README.md index 4df4039..48de098 100644 --- a/README.md +++ b/README.md @@ -1,31 +1,58 @@ -# bump-pydantic +# Bump Pydantic ♻️ -[![PyPI - Version](https://img.shields.io/pypi/v/bump-pydantic.svg)](https://pypi.org/project/bump-pydantic) -[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/bump-pydantic.svg)](https://pypi.org/project/bump-pydantic) + + +Utility to bump pydantic from V1 to V2. ----- -**Table of Contents** +### Rules + +#### BP001: Replace imports + +- ✅ Replace `BaseSettings` from `pydantic` to `pydantic_settings`. +- ✅ Replace `Color` and `PaymentCardNumber` from `pydantic` to `pydantic_extra_types`. + +#### BP002: Add default `None` to `Optional[T]`, `Union[T, None]` and `Any` fields + +- ✅ Add default `None` to `Optional[T]` fields. + +The following code will be transformed: -- [bump-pydantic](#bump-pydantic) - - [Installation](#installation) - - [Usage](#usage) - - [License](#license) +```py +class User(BaseModel): + name: Optional[str] +``` -## Installation +Into: -```console -pip install bump-pydantic +```py +class User(BaseModel): + name: Optional[str] = None ``` -## Usage +#### BP003: Replace `Config` class by `model_config` + +- ✅ Replace `Config` class by `model_config = ConfigDict()`. + +The following code will be transformed: -You can run `bump-pydantic` from the command line: +```py +class User(BaseModel): + name: str -```console -bump-pydantic + class Config: + extra = 'forbid' ``` -## License +Into: + +```py +class User(BaseModel): + name: str + + model_config = ConfigDict(extra='forbid') +``` -`bump-pydantic` is distributed under the terms of the [MIT](https://spdx.org/licenses/MIT.html) license. +#### BP004: Replace `BaseModel` methods diff --git a/bump_pydantic/__main__.py b/bump_pydantic/__main__.py index f6c6b64..10ff24e 100644 --- a/bump_pydantic/__main__.py +++ b/bump_pydantic/__main__.py @@ -1,98 +1,4 @@ -import difflib -import os -import sys -import time -from pathlib import Path - -import libcst as cst -from libcst.codemod import CodemodContext -from libcst.helpers import calculate_module_and_package -from libcst.metadata import FullRepoManager, PositionProvider, ScopeProvider -from libcst_mypy import MypyTypeInferenceProvider -from typer import Argument, Option, Typer - -from bump_pydantic.transformers import gather_transformers - -app = Typer(help="Convert Pydantic from V1 to V2.") - - -@app.command() -def main( - package: Path = Argument(..., exists=True, dir_okay=True, allow_dash=False), - diff: bool = Option(False, help="Show diff instead of applying changes."), - debug: bool = Option(False, help="Show debug logs."), - add_default_none: bool = True, - # NOTE: It looks like there are some issues with the libcst.codemod.RenameCommand. - # To replicate the issue: clone aiopenapi3, and run `python -m bump_pydantic aiopenapi3`. - # For that reason, the default is False. - rename_imports: bool = False, - rename_methods: bool = True, - replace_config_class: bool = True, - replace_config_parameters: bool = True, -) -> None: - # sourcery skip: hoist-similar-statement-from-if, simplify-len-comparison, swap-nested-ifs - files = [str(path.absolute()) for path in package.glob("**/*.py")] - - transformers = gather_transformers( - add_default_none=add_default_none, - rename_imports=rename_imports, - rename_methods=rename_methods, - replace_config_class=replace_config_class, - replace_config_parameters=replace_config_parameters, - ) - - cwd = os.getcwd() - providers = {MypyTypeInferenceProvider, ScopeProvider, PositionProvider} - metadata_manager = FullRepoManager(cwd, files, providers=providers) - print("Inferring types... This may take a while.") - metadata_manager.resolve_cache() - print("Types are inferred.") - - start_time = time.time() - - # TODO: We can run this in parallel - batch it into files / cores. - # We may need to run the resolve_cache() on each core - not sure. - for transformer in transformers: - for filename in files: - module_and_package = calculate_module_and_package(cwd, filename) - transform = transformer( - CodemodContext( - metadata_manager=metadata_manager, - filename=filename, - full_module_name=module_and_package.name, - full_package_name=module_and_package.package, - ) - ) - if debug: - print(f"Processing {filename} with {transform.__class__.__name__}") - - with open(filename) as fp: - old_code = fp.read() - - input_tree = cst.parse_module(old_code) - output_tree = transform.transform_module(input_tree) - - input_code = input_tree.code - output_code = output_tree.code - - if input_code != output_code: - if diff: - # TODO: Should be colored. - lines = difflib.unified_diff( - input_code.splitlines(keepends=True), - output_code.splitlines(keepends=True), - fromfile=filename, - tofile=filename, - ) - sys.stdout.writelines(lines) - else: - with open(filename, "w") as fp: - fp.write(output_tree.code) - - modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time] - if len(modified) > 0: - print(f"Refactored {len(modified)} files.") - +from bump_pydantic.main import app if __name__ == "__main__": app() diff --git a/bump_pydantic/codemods/__init__.py b/bump_pydantic/codemods/__init__.py new file mode 100644 index 0000000..0f2059b --- /dev/null +++ b/bump_pydantic/codemods/__init__.py @@ -0,0 +1,18 @@ +from typing import List, Type + +from libcst.codemod import ContextAwareTransformer +from libcst.codemod.visitors import AddImportsVisitor + +from bump_pydantic.codemods.add_default_none import AddDefaultNoneCommand +from bump_pydantic.codemods.replace_config import ReplaceConfigCodemod +from bump_pydantic.codemods.replace_imports import ReplaceImportsCodemod + + +def gather_codemods() -> List[Type[ContextAwareTransformer]]: + return [ + AddDefaultNoneCommand, + ReplaceConfigCodemod, + ReplaceImportsCodemod, + # AddImportsVisitor needs to be the last. + AddImportsVisitor, + ] diff --git a/bump_pydantic/commands/add_default_none.py b/bump_pydantic/codemods/add_default_none.py similarity index 51% rename from bump_pydantic/commands/add_default_none.py rename to bump_pydantic/codemods/add_default_none.py index 463a90c..28c8334 100644 --- a/bump_pydantic/commands/add_default_none.py +++ b/bump_pydantic/codemods/add_default_none.py @@ -2,10 +2,12 @@ import libcst as cst import libcst.matchers as m -from libcst._nodes.statement import AnnAssign, ClassDef from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand -from libcst_mypy import MypyTypeInferenceProvider -from mypy.nodes import TypeInfo +from libcst.metadata import FullyQualifiedNameProvider, QualifiedName + +from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor +from bump_pydantic.markers.find_base_model import CONTEXT_KEY as BASE_MODEL_CONTEXT_KEY +from bump_pydantic.markers.find_base_model import find_base_model class AddDefaultNoneCommand(VisitorBasedCodemodCommand): @@ -14,61 +16,54 @@ class AddDefaultNoneCommand(VisitorBasedCodemodCommand): Example:: # Before + ```py from pydantic import BaseModel class Foo(BaseModel): bar: Optional[str] baz: Union[str, None] qux: Any + ``` # After + ```py from pydantic import BaseModel class Foo(BaseModel): bar: Optional[str] = None baz: Union[str, None] = None qux: Any = None + ``` """ - METADATA_DEPENDENCIES = (MypyTypeInferenceProvider,) + METADATA_DEPENDENCIES = { + FullyQualifiedNameProvider, + } - def __init__(self, context: CodemodContext, class_name: str) -> None: + def __init__(self, context: CodemodContext) -> None: super().__init__(context) - self.class_name = class_name self.inside_base_model = False self.should_add_none = False - def visit_ClassDef(self, node: ClassDef) -> None: - for base in node.bases: - scope = self.get_metadata(MypyTypeInferenceProvider, base.value, None) - if scope is not None and isinstance(scope.mypy_type, TypeInfo): - self.inside_base_model = self._is_class_name_base_of_type_info( - self.class_name, scope.mypy_type - ) - - def _is_class_name_base_of_type_info( - self, class_name: str, type_info: TypeInfo - ) -> bool: - if type_info.fullname == class_name: - return True - return any( - self._is_class_name_base_of_type_info(class_name, base.type) - for base in type_info.bases - ) - - def leave_ClassDef( - self, original_node: ClassDef, updated_node: ClassDef - ) -> ClassDef: + def visit_ClassDef(self, node: cst.ClassDef) -> None: + fqn_set = self.get_metadata(FullyQualifiedNameProvider, node) + + if not fqn_set: + return None + + fqn: QualifiedName = next(iter(fqn_set)) # type: ignore + if fqn.name in self.context.scratch[BASE_MODEL_CONTEXT_KEY]: + self.inside_base_model = True + + def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: self.inside_base_model = False return updated_node - def visit_AnnAssign(self, node: AnnAssign) -> bool | None: + def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None: if m.matches( node.annotation.annotation, - m.Subscript( - m.Name("Optional") | m.Attribute(m.Name("typing"), m.Name("Optional")) - ) + m.Subscript(m.Name("Optional") | m.Attribute(m.Name("typing"), m.Name("Optional"))) | m.Subscript( m.Name("Union") | m.Attribute(m.Name("typing"), m.Name("Union")), slice=[ @@ -79,21 +74,15 @@ def visit_AnnAssign(self, node: AnnAssign) -> bool | None: ) | m.Name("Any") | m.Attribute(m.Name("typing"), m.Name("Any")) - # TODO: This can be recursive. + # TODO: This can be recursive. Can it? | m.BinaryOperation(operator=m.BitOr(), left=m.Name("None")) | m.BinaryOperation(operator=m.BitOr(), right=m.Name("None")), ): self.should_add_none = True return super().visit_AnnAssign(node) - def leave_AnnAssign( - self, original_node: AnnAssign, updated_node: AnnAssign - ) -> AnnAssign: - if ( - self.inside_base_model - and self.should_add_none - and updated_node.value is None - ): + def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign: + if self.inside_base_model and self.should_add_none and updated_node.value is None: updated_node = updated_node.with_changes(value=cst.Name("None")) self.inside_an_assign = False self.should_add_none = False @@ -103,11 +92,13 @@ def leave_AnnAssign( if __name__ == "__main__": import os import textwrap + from pathlib import Path from tempfile import TemporaryDirectory from libcst.metadata import FullRepoManager + from rich.pretty import pprint - with TemporaryDirectory() as tmpdir: + with TemporaryDirectory(dir=os.getcwd()) as tmpdir: package_dir = f"{tmpdir}/package" os.mkdir(package_dir) module_path = f"{package_dir}/a.py" @@ -117,22 +108,29 @@ def leave_AnnAssign( from pydantic import BaseModel class Foo(BaseModel): - bar: Optional[str] - baz: Union[str, None] - qux: Any + a: Optional[str] + + class Bar(Foo): + b: Optional[str] + c: Union[str, None] + d: Any + + foo = Foo(a="text") + foo.dict() """ ) - print(content) - print("=" * 80) f.write(content) - f.seek(0) - module = cst.parse_module(content) - mrg = FullRepoManager( - package_dir, {module_path}, providers={MypyTypeInferenceProvider} - ) - wrapper = mrg.get_metadata_wrapper_for_path(module_path) + module = str(Path(module_path).relative_to(tmpdir)) + mrg = FullRepoManager(tmpdir, {module}, providers={FullyQualifiedNameProvider}) + wrapper = mrg.get_metadata_wrapper_for_path(module) context = CodemodContext(wrapper=wrapper) - command = AddDefaultNoneCommand( - context=context, class_name="pydantic.main.BaseModel" - ) - print(wrapper.visit(command).code) + + command = ClassDefVisitor(context=context) + mod = wrapper.visit(command) + + find_base_model(context=context) + pprint(context.scratch) + + command = AddDefaultNoneCommand(context=context) # type: ignore[assignment] + mod = wrapper.visit(command) + print(mod.code) diff --git a/bump_pydantic/codemods/class_def_visitor.py b/bump_pydantic/codemods/class_def_visitor.py new file mode 100644 index 0000000..17e54e0 --- /dev/null +++ b/bump_pydantic/codemods/class_def_visitor.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from collections import defaultdict + +import libcst as cst +from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand +from libcst.metadata import FullyQualifiedNameProvider, QualifiedName + + +class ClassDefVisitor(VisitorBasedCodemodCommand): + METADATA_DEPENDENCIES = {FullyQualifiedNameProvider} + + CONTEXT_KEY = "class_def_visitor" + + def __init__(self, context: CodemodContext) -> None: + super().__init__(context) + self.module_fqn: None | QualifiedName = None + self.context.scratch.setdefault(self.CONTEXT_KEY, defaultdict(set)) + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + fqn_set = self.get_metadata(FullyQualifiedNameProvider, node) + + if not fqn_set: + return None + + fqn: QualifiedName = next(iter(fqn_set)) # type: ignore + for arg in node.bases: + base_fqn_set = self.get_metadata(FullyQualifiedNameProvider, arg.value) + + if not base_fqn_set: + return None + + base_fqn: QualifiedName = next(iter(base_fqn_set)) # type: ignore + # NOTE: Should I use the name or the QualifiedName? + self.context.scratch[self.CONTEXT_KEY][fqn.name].add(base_fqn.name) + + +if __name__ == "__main__": + import os + import textwrap + from pathlib import Path + from tempfile import TemporaryDirectory + + from libcst.metadata import FullRepoManager + from rich.pretty import pprint + + with TemporaryDirectory(dir=os.getcwd()) as tmpdir: + package_dir = f"{tmpdir}/package" + os.mkdir(package_dir) + module_path = f"{package_dir}/a.py" + with open(module_path, "w") as f: + content = textwrap.dedent( + """ + from pydantic import BaseModel + + class Foo(BaseModel): + a: str + + class Bar(Foo): + b: str + + foo = Foo(a="text") + foo.dict() + """ + ) + f.write(content) + module = str(Path(module_path).relative_to(tmpdir)) + mrg = FullRepoManager(tmpdir, {module}, providers={FullyQualifiedNameProvider}) + wrapper = mrg.get_metadata_wrapper_for_path(module) + context = CodemodContext(wrapper=wrapper) + command = ClassDefVisitor(context=context) + mod = wrapper.visit(command) + pprint(context.scratch[ClassDefVisitor.CONTEXT_KEY]) diff --git a/bump_pydantic/codemods/replace_config.py b/bump_pydantic/codemods/replace_config.py new file mode 100644 index 0000000..de228a5 --- /dev/null +++ b/bump_pydantic/codemods/replace_config.py @@ -0,0 +1,181 @@ +from typing import List + +import libcst as cst +from libcst import matchers as m +from libcst._nodes.module import Module +from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand +from libcst.codemod.visitors import AddImportsVisitor +from libcst.metadata import PositionProvider + +REMOVED_KEYS = [ + "allow_mutation", + "error_msg_templates", + "fields", + "getter_dict", + "smart_union", + "underscore_attrs_are_private", + "json_loads", + "json_dumps", + "json_encoders", + "copy_on_model_validation", + "post_init_call", +] +RENAMED_KEYS = { + "allow_population_by_field_name": "populate_by_name", + "anystr_lower": "str_to_lower", + "anystr_strip_whitespace": "str_strip_whitespace", + "anystr_upper": "str_to_upper", + "keep_untouched": "ignored_types", + "max_anystr_length": "str_max_length", + "min_anystr_length": "str_min_length", + "orm_mode": "from_attributes", + "schema_extra": "json_schema_extra", + "validate_all": "validate_default", +} +# TODO: The codemod should not replace `Config` in case of removed keys, right? + +base_model_with_config = m.ClassDef( + bases=[ + m.ZeroOrMore(), + m.Arg(), + m.ZeroOrMore(), + ], + body=m.IndentedBlock( + body=[ + m.ZeroOrMore(), + m.ClassDef(name=m.Name(value="Config"), bases=[]), + m.ZeroOrMore(), + ] + ), +) +base_model_with_config_child = m.ClassDef( + bases=[ + m.ZeroOrMore(), + m.Arg(), + m.ZeroOrMore(), + ], + body=m.IndentedBlock( + body=[ + m.ZeroOrMore(), + m.ClassDef(name=m.Name(value="Config"), bases=[m.AtLeastN(n=1)]), + m.ZeroOrMore(), + ] + ), +) + + +class ReplaceConfigCodemod(VisitorBasedCodemodCommand): + """Replace `Config` class by `ConfigDict` call.""" + + METADATA_DEPENDENCIES = (PositionProvider,) + + def __init__(self, context: CodemodContext) -> None: + super().__init__(context) + + self.inside_config_class = False + + self.config_args: List[cst.Arg] = [] + + @m.visit(m.ClassDef(name=m.Name(value="Config"))) + def visit_config_class(self, node: cst.ClassDef) -> None: + self.inside_config_class = True + + @m.leave(m.ClassDef(name=m.Name(value="Config"))) + def leave_config_class(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: + self.inside_config_class = False + return updated_node + + def visit_Assign(self, node: cst.Assign) -> None: + # NOTE: There's no need for the `leave_Assign`. + self.assign_value = node.value + + def visit_AssignTarget(self, node: cst.AssignTarget) -> None: + self.config_args.append( + cst.Arg( + keyword=node.target, # type: ignore[arg-type] + value=self.assign_value, + equal=cst.AssignEqual( + whitespace_before=cst.SimpleWhitespace(""), + whitespace_after=cst.SimpleWhitespace(""), + ), + ) + ) + + def leave_Module(self, original_node: Module, updated_node: Module) -> Module: + return updated_node + + @m.leave(base_model_with_config_child) + def leave_config_class_child(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: + position = self.get_metadata(PositionProvider, original_node) + print("You'll need to manually replace the `Config` class to the `model_config` attribute.") + print( + "File: {filename}:-{start_line},{start_column}:{end_line},{end_column}".format( + filename=self.context.filename, + start_line=position.start.line, + start_column=position.start.column, + end_line=position.end.line, + end_column=position.end.column, + ) + ) + return updated_node + + @m.leave(base_model_with_config) + def leave_config_class_childless(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: + """Replace the `Config` class with a `model_config` attribute. + + Any class that contains a `Config` class will have that class replaced + with a `model_config` attribute. The `model_config` attribute will be + assigned a `ConfigDict` object with the same arguments as the attributes + from `Config` class. + """ + AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="ConfigDict") + block = cst.ensure_type(original_node.body, cst.IndentedBlock) + body = [ + cst.SimpleStatementLine( + body=[ + cst.Assign( + targets=[cst.AssignTarget(target=cst.Name("model_config"))], + value=cst.Call( + func=cst.Name("ConfigDict"), + args=self.config_args, + ), + ) + ], + ) + if m.matches(statement, m.ClassDef(name=m.Name(value="Config"))) + else statement + for statement in block.body + ] + self.config_args = [] + return updated_node.with_changes(body=updated_node.body.with_changes(body=body)) + + +if __name__ == "__main__": + import textwrap + + from rich.console import Console + + console = Console() + + source = textwrap.dedent( + """ + from pydantic import BaseModel + + class A(BaseModel): + class Config: + arbitrary_types_allowed = True + """ + ) + console.print(source) + console.print("=" * 80) + + mod = cst.parse_module(source) + context = CodemodContext(filename="main.py") + wrapper = cst.MetadataWrapper(mod) + command = ReplaceConfigCodemod(context=context) + + mod = wrapper.visit(command) + wrapper = cst.MetadataWrapper(mod) + command = AddImportsVisitor(context=context) # type: ignore[assignment] + mod = wrapper.visit(command) + console.print(mod.code) diff --git a/bump_pydantic/codemods/replace_imports.py b/bump_pydantic/codemods/replace_imports.py new file mode 100644 index 0000000..7baf634 --- /dev/null +++ b/bump_pydantic/codemods/replace_imports.py @@ -0,0 +1,150 @@ +""" +This codemod deals with the following cases: + +1. `from pydantic import BaseSettings` +2. `from pydantic.settings import BaseSettings` +3. `from pydantic import BaseSettings as ` +4. `from pydantic.settings import BaseSettings as ` # TODO: This is not working. +5. `import pydantic` -> `pydantic.BaseSettings` +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Sequence + +import libcst as cst +import libcst.matchers as m +from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand +from libcst.codemod.visitors import AddImportsVisitor + +IMPORTS = { + "pydantic:BaseSettings": ("pydantic_settings", "BaseSettings"), + "pydantic.settings:BaseSettings": ("pydantic_settings", "BaseSettings"), + "pydantic:Color": ("pydantic_extra_types.color", "Color"), + "pydantic.color:Color": ("pydantic_extra_types.color", "Color"), + "pydantic:PaymentCardNumber": ("pydantic_extra_types.payment", "PaymentCardNumber"), + "pydantic.payment:PaymentCardBrand": ( + "pydantic_extra_types.payment", + "PaymentCardBrand", + ), + "pydantic.payment:PaymentCardNumber": ( + "pydantic_extra_types.payment", + "PaymentCardNumber", + ), +} + + +def resolve_module_parts(module_parts: list[str]) -> m.Attribute | m.Name: + if len(module_parts) == 1: + return m.Name(module_parts[0]) + if len(module_parts) == 2: + first, last = module_parts + return m.Attribute(value=m.Name(first), attr=m.Name(last)) + last_name = module_parts.pop() + attr = resolve_module_parts(module_parts) + return m.Attribute(value=attr, attr=m.Name(last_name)) + + +def get_import_from_from_str(import_str: str) -> m.ImportFrom: + """Converts a string like `pydantic:BaseSettings` to an `ImportFrom` node. + + Examples: + >>> get_import_from_from_str("pydantic:BaseSettings") + ImportFrom( + module=Name("pydantic"), + names=[ImportAlias(name=Name("BaseSettings"))], + ) + >>> get_import_from_from_str("pydantic.settings:BaseSettings") + ImportFrom( + module=Attribute(value=Name("pydantic"), attr=Name("settings")), + names=[ImportAlias(name=Name("BaseSettings"))], + ) + >>> get_import_from_from_str("a.b.c:d") + ImportFrom( + module=Attribute( + value=Attribute(value=Name("a"), attr=Name("b")), attr=Name("c") + ), + names=[ImportAlias(name=Name("d"))], + ) + """ + module, name = import_str.split(":") + module_parts = module.split(".") + module_node = resolve_module_parts(module_parts) + return m.ImportFrom( + module=module_node, + names=[m.ZeroOrMore(), m.ImportAlias(name=m.Name(value=name)), m.ZeroOrMore()], + ) + + +@dataclass +class ImportInfo: + import_from: m.ImportFrom + import_str: str + to_import_str: tuple[str, str] + + +IMPORT_INFOS = [ + ImportInfo( + import_from=get_import_from_from_str(import_str), + import_str=import_str, + to_import_str=to_import_str, + ) + for import_str, to_import_str in IMPORTS.items() +] +IMPORT_MATCH = m.OneOf(*[info.import_from for info in IMPORT_INFOS]) + + +class ReplaceImportsCodemod(VisitorBasedCodemodCommand): + @m.leave(IMPORT_MATCH) + def leave_replace_import(self, _: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom: + for import_info in IMPORT_INFOS: + if m.matches(updated_node, import_info.import_from): + aliases: Sequence[cst.ImportAlias] = updated_node.names # type: ignore + # If multiple objects are imported in a single import statement, + # we need to remove only the one we're replacing. + AddImportsVisitor.add_needed_import(self.context, *import_info.to_import_str) + if len(updated_node.names) > 1: # type: ignore + names = [alias for alias in aliases if alias.name.value != import_info.to_import_str[-1]] + updated_node = updated_node.with_changes(names=names) + else: + return cst.RemoveFromParent() # type: ignore[return-value] + return updated_node + + +if __name__ == "__main__": + import textwrap + + from rich.console import Console + + console = Console() + + source = textwrap.dedent( + """ + from pydantic.settings import BaseSettings + from pydantic.color import Color + from pydantic.payment import PaymentCardNumber, PaymentCardBrand + from pydantic import Color + from pydantic import Color as Potato + + + class Potato(BaseSettings): + color: Color + payment: PaymentCardNumber + brand: PaymentCardBrand + potato: Potato + """ + ) + console.print(source) + console.print("=" * 80) + + mod = cst.parse_module(source) + context = CodemodContext(filename="main.py") + wrapper = cst.MetadataWrapper(mod) + command = ReplaceImportsCodemod(context=context) + console.print(mod) + + mod = wrapper.visit(command) + wrapper = cst.MetadataWrapper(mod) + command = AddImportsVisitor(context=context) # type: ignore[assignment] + mod = wrapper.visit(command) + console.print(mod.code) diff --git a/bump_pydantic/codemods/replace_methods.py b/bump_pydantic/codemods/replace_methods.py new file mode 100644 index 0000000..ab31b69 --- /dev/null +++ b/bump_pydantic/codemods/replace_methods.py @@ -0,0 +1,153 @@ +"""The codemod that replaces deprecated methods with their new counterparts. + +This codemod replaces the following methods: +- `dict` -> `model_dump` +- `json` -> `model_dump_json` +- `parse_obj` -> `model_validate` +- `construct` -> `model_construct` +- `copy` -> `model_copy` +- `schema` -> `model_json_schema` +- `validate` -> `model_validate` + +There are two cases this codemod handles: + +1. Known BaseModel subclasses: +```py +class A(BaseModel): + ... + +model = A() +model.dict() +``` + +2. Type annotation: +```py +def func(model: A): + model.dict() + +3. Known BaseModel instance by call inference: +```py +def func() -> A: + ... + +model = func() +model.dict() +``` + +4. Known BaseModel subclass imported from another module. +```py +from project.add_none import A + +model = A() +model.dict() +``` + +5. Known instance imported from another module. +```py +from project import model + +model.dict() +``` +""" + + +from __future__ import annotations + +import libcst as cst +import libcst.matchers as m +from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand +from libcst.codemod.visitors import AddImportsVisitor +from libcst.metadata import QualifiedNameProvider, ScopeProvider + +# NOTE: Unsure what do to with the following methods: +# - parse_raw +# - parse_file +# - from_orm +# - schema_json + +DEPRECATED_METHODS = { + "dict": "model_dump", + "json": "model_dump_json", + "parse_obj": "model_validate", + "construct": "model_construct", + "copy": "model_copy", + "schema": "model_json_schema", + "validate": "model_validate", +} + +MATCH_DEPRECATED_METHODS = m.Call( + func=m.Attribute(attr=m.Name(value=m.MatchIfTrue(lambda value: value in DEPRECATED_METHODS))) +) + + +class ReplaceMethodsCodemod(VisitorBasedCodemodCommand): + METADATA_DEPENDENCIES = (ScopeProvider, QualifiedNameProvider) + + def visit_AssignTarget(self, node: cst.AssignTarget) -> bool | None: + print(node) + return super().visit_AssignTarget(node) + + def visit_Assign(self, node: cst.Assign) -> bool | None: + print(node) + return super().visit_Assign(node) + + # TODO: Add a warning in case you find a method that matches the rules, but it's not + # identified as a BaseModel instance. + @m.leave(MATCH_DEPRECATED_METHODS) + def leave_deprecated_methods(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: + self.get_metadata(QualifiedNameProvider, original_node) + # print("hi") + self.get_metadata(ScopeProvider, original_node) + # if isinstance(scope, GlobalScope): + # print(scope.globals) + # print(scope.parent) + + # for assignment in scope.assignments: + # if isinstance(assignment, Assignment): + # print(assignment.name) + # print(assignment.references) + # print(assignment.node) + # print() + # scope.get_qualified_names_for(original_node) + # print(scope.get_qualified_names_for(original_node)) + # for assignment in scope.assignments: + # print(assignment.name) + # print(assignment.references) + # print() + return updated_node + + +if __name__ == "__main__": + import textwrap + + from rich.console import Console + + console = Console() + + source = textwrap.dedent( + """ + from pydantic import BaseModel + + class A(BaseModel): + a: int + + class B(A): + b: int + + model = B(a=1, b=2) + model.dict() + """ + ) + console.print(source) + console.print("=" * 80) + + mod = cst.parse_module(source) + context = CodemodContext(filename="main.py") + wrapper = cst.MetadataWrapper(mod) + command = ReplaceMethodsCodemod(context=context) + + mod = wrapper.visit(command) + wrapper = cst.MetadataWrapper(mod) + command = AddImportsVisitor(context=context) # type: ignore[assignment] + mod = wrapper.visit(command) + console.print(mod.code) diff --git a/bump_pydantic/commands/rename_method_call.py b/bump_pydantic/commands/rename_method_call.py deleted file mode 100644 index 2fce57e..0000000 --- a/bump_pydantic/commands/rename_method_call.py +++ /dev/null @@ -1,116 +0,0 @@ -from __future__ import annotations - -import libcst as cst -from typing import cast -from libcst import matchers as m -from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand -from libcst_mypy import MypyTypeInferenceProvider -from mypy.nodes import TypeInfo -from mypy.types import Instance - - -class RenameMethodCallCommand(VisitorBasedCodemodCommand): - """This codemod renames a method call of a class. - - Example:: - # Given the following class and method mapping: - # class_name = "pydantic.main.BaseModel" - # methods = {"dict": "model_dump"} - - # Before - - from pydantic import BaseModel - - class Foo(BaseModel): - bar: str - - foo = Foo(bar="text") - foo.dict() - - # After - - from pydantic import BaseModel - - class Foo(BaseModel): - bar: str - - foo = Foo(bar="text") - foo.model_dump() - """ - - METADATA_DEPENDENCIES = (MypyTypeInferenceProvider,) - - def __init__( - self, - context: CodemodContext, - class_name: str, - methods: dict[str, str], - ) -> None: - super().__init__(context) - self.class_name = class_name - self.methods = methods - - def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: - node = original_node - if m.matches(node.func, m.Attribute()): - func = cst.ensure_type(node.func, cst.Attribute) - scope = self.get_metadata(MypyTypeInferenceProvider, func.value, None) - if scope is not None: - mypy_type = scope.mypy_type - if isinstance(mypy_type, Instance): - info = mypy_type.type - if self._is_class_name_base_of_type_info(self.class_name, info): - new_method = self.methods.get(func.attr.value) - if new_method is not None: - attr = func.attr.with_changes(value=new_method) - func_with_changes = func.with_changes(attr=attr) - return updated_node.with_changes(func=func_with_changes) - return updated_node - - def _is_class_name_base_of_type_info( - self, class_name: str, type_info: TypeInfo - ) -> bool: - if type_info.fullname == class_name: - return True - return any( - self._is_class_name_base_of_type_info(class_name, base.type) - for base in type_info.bases - ) - - -if __name__ == "__main__": - import os - import textwrap - from tempfile import TemporaryDirectory - - from libcst.metadata import FullRepoManager - - with TemporaryDirectory() as tmpdir: - package_dir = f"{tmpdir}/package" - os.mkdir(package_dir) - module_path = f"{package_dir}/a.py" - with open(module_path, "w") as f: - content = textwrap.dedent( - """ - from pydantic import BaseModel - - class Foo(BaseModel): - bar: str - - foo = Foo(bar="text") - foo.dict() - """ - ) - f.write(content) - f.seek(0) - mrg = FullRepoManager( - package_dir, {module_path}, providers={MypyTypeInferenceProvider} - ) - wrapper = mrg.get_metadata_wrapper_for_path(module_path) - context = CodemodContext(wrapper=wrapper) - command = RenameMethodCallCommand( - context=context, - class_name="pydantic.main.BaseModel", - methods={"dict": "model_dump"}, - ) - print(wrapper.visit(command).code) diff --git a/bump_pydantic/commands/replace_call_param.py b/bump_pydantic/commands/replace_call_param.py deleted file mode 100644 index 0621203..0000000 --- a/bump_pydantic/commands/replace_call_param.py +++ /dev/null @@ -1,92 +0,0 @@ -"""Replace a parameter in a function call.""" -from __future__ import annotations - -import libcst as cst -import libcst.matchers as m -from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand -from libcst.metadata import ImportAssignment, ScopeProvider - - -class ReplaceCallParam(VisitorBasedCodemodCommand): - """Replace a parameter in a function call. - - We visit the call to check if it's one of the callers provided, and if so, - we visit the arguments to check if it's the old parameter. If so, we replace - it with the new parameter. - """ - - METADATA_DEPENDENCIES = (ScopeProvider,) - - def __init__( - self, - context: CodemodContext, - callers: tuple[str, ...], - params: dict[str, str], - ) -> None: - super().__init__(context) - self.callers = callers - self.params = params - - self.inside_caller = False - - def visit_Call(self, node: cst.Call) -> None: - scope = self.get_metadata(ScopeProvider, node) - if scope is None: - return - for assignment in scope.assignments: - if isinstance(assignment, ImportAssignment): - qualified_names = assignment.get_qualified_names_for(assignment.name) - exact_path = any(qn.name in self.callers for qn in qualified_names) - - # When the qualified_names don't have the object that is going to be - # used, we need to verify if the module is in the list of callers. - caller_modules = [caller.rsplit(".", 1)[0] for caller in self.callers] - module_match = any(qn.name in caller_modules for qn in qualified_names) - if exact_path or module_match: - self.inside_caller = True - - def leave_Call( - self, original_node: cst.Call, updated_node: cst.Call - ) -> cst.BaseExpression: - self.inside_caller = False - return updated_node - - def leave_Arg(self, original_node: cst.Arg, updated_node: cst.Arg) -> cst.Arg: - is_old_param = m.matches( - updated_node, - m.Arg(keyword=m.Name(m.MatchIfTrue(lambda x: x in self.params.keys()))), - ) - if self.inside_caller and is_old_param: - return updated_node.with_changes( - keyword=cst.Name(self.params[updated_node.keyword.value]) - ) - return updated_node - - -if __name__ == "__main__": - import textwrap - - from rich.console import Console - - console = Console() - - source = textwrap.dedent( - """ - from pydantic.config import ConfigDict as ConfigDicto - - ConfigDicto(potato="potato") - """ - ) - console.print(source) - console.print("=" * 80) - - mod = cst.parse_module(source) - context = CodemodContext(filename="test.py") - wrapper = cst.MetadataWrapper(mod) - command = ReplaceCallParam( - context=context, - callers=("pydantic.ConfigDict", "pydantic.config.ConfigDict"), - params={"potato": "param"}, - ) - mod = wrapper.visit(command) - console.print(mod.code) diff --git a/bump_pydantic/commands/replace_config_class.py b/bump_pydantic/commands/replace_config_class.py deleted file mode 100644 index 1b8de92..0000000 --- a/bump_pydantic/commands/replace_config_class.py +++ /dev/null @@ -1,116 +0,0 @@ -from typing import List - -import libcst as cst -from libcst import matchers as m -from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand -from libcst.metadata import PositionProvider -from libcst.codemod.visitors import AddImportsVisitor - - -base_model_with_config = m.ClassDef( - bases=[ - m.ZeroOrMore(), - m.Arg(), - m.ZeroOrMore(), - ], - body=m.IndentedBlock( - body=[ - m.ZeroOrMore(), - m.ClassDef(name=m.Name(value="Config"), bases=[]), - m.ZeroOrMore(), - ] - ), -) -base_model_with_config_child = m.ClassDef( - bases=[ - m.ZeroOrMore(), - m.Arg(), - m.ZeroOrMore(), - ], - body=m.IndentedBlock( - body=[ - m.ZeroOrMore(), - m.ClassDef(name=m.Name(value="Config"), bases=[m.AtLeastN(n=1)]), - m.ZeroOrMore(), - ] - ), -) - - -class ReplaceConfigClassByDict(VisitorBasedCodemodCommand): - """Replace `Config` class by `ConfigDict` call.""" - - METADATA_DEPENDENCIES = (PositionProvider,) - - def __init__(self, context: CodemodContext) -> None: - super().__init__(context) - self.config_args: List[cst.Arg] = [] - - @m.visit(m.ClassDef(name=m.Name(value="Config"))) - def visit_config_class(self, node: cst.ClassDef) -> None: - """Collect the arguments from the `Config` class.""" - for statement in node.body.body: - if m.matches(statement, m.SimpleStatementLine()): - statement = cst.ensure_type(statement, cst.SimpleStatementLine) - for child in statement.body: - if m.matches(child, m.Assign()): - assignment = cst.ensure_type(child, cst.Assign) - assign_target = cst.ensure_type( - assignment.targets[0], cst.AssignTarget - ) - keyword = cst.ensure_type(assign_target.target, cst.Name) - keyword = keyword.with_changes(value=keyword.value) - arg = cst.Arg( - value=assignment.value, - keyword=keyword, - equal=cst.AssignEqual( - whitespace_before=cst.SimpleWhitespace(""), - whitespace_after=cst.SimpleWhitespace(""), - ), - ) - self.config_args.append(arg) - - @m.leave(base_model_with_config_child) - def leave_config_class_child( - self, original_node: cst.ClassDef, updated_node: cst.ClassDef - ) -> cst.ClassDef: - position = self.get_metadata(PositionProvider, original_node) - print( - "You'll need to manually replace the `Config` class to the `model_config` attribute." - ) - print(f"File: {self.context.filename}:-{position.start.line},{position.start.column}:{position.end.line},{position.end.column}") - return updated_node - - @m.leave(base_model_with_config) - def leave_config_class( - self, original_node: cst.ClassDef, updated_node: cst.ClassDef - ) -> cst.ClassDef: - """Replace the `Config` class with a `model_config` attribute. - - Any class that contains a `Config` class will have that class replaced - with a `model_config` attribute. The `model_config` attribute will be - assigned a `ConfigDict` object with the same arguments as the attributes - from `Config` class. - """ - AddImportsVisitor.add_needed_import( - context=self.context, module="pydantic", obj="ConfigDict" - ) - block = cst.ensure_type(original_node.body, cst.IndentedBlock) - body = [ - cst.SimpleStatementLine( - body=[ - cst.Assign( - targets=[cst.AssignTarget(target=cst.Name("model_config"))], - value=cst.Call( - func=cst.Name("ConfigDict"), - args=self.config_args, - ), - ) - ], - ) - if m.matches(statement, m.ClassDef(name=m.Name(value="Config"))) - else statement - for statement in block.body - ] - self.config_args = [] - return updated_node.with_changes(body=updated_node.body.with_changes(body=body)) diff --git a/bump_pydantic/commands/use_settings.py b/bump_pydantic/commands/use_settings.py deleted file mode 100644 index 08e6d2a..0000000 --- a/bump_pydantic/commands/use_settings.py +++ /dev/null @@ -1,24 +0,0 @@ -from libcst.codemod import CodemodContext -from libcst.codemod.commands.rename import RenameCommand - - -def UsePydanticSettingsCommand(context: CodemodContext): - """Support for pydantic.BaseSettings. - - This command will rename pydantic.BaseSettings to pydantic_settings:BaseSettings. - - It doesn't support the following cases: - - from pydantic.settings import BaseSettings - - import pydantic ... class Settings(pydantic.BaseSettings) - - import pydantic as pd ... class Settings(pd.BaseSettings) - - TODO: Support the above cases. To implement the above, you'll need to go to each - `ClassDef`, and see the bases. If there's a `pydantic.settings.BaseSettings` in the - bases, then you'll need to use `RemoveImportsVisitor` and `AddImportsVisitor` from - `libcst.codemod.visitors`. - """ - return RenameCommand( - context=context, - old_name="pydantic.BaseSettings", - new_name="pydantic_settings:BaseSettings", - ) diff --git a/bump_pydantic/main.py b/bump_pydantic/main.py new file mode 100644 index 0000000..da18af8 --- /dev/null +++ b/bump_pydantic/main.py @@ -0,0 +1,107 @@ +import difflib +import os +import sys +import time +from pathlib import Path +from typing import Any, Dict + +import libcst as cst +from libcst.codemod import CodemodContext +from libcst.helpers import calculate_module_and_package +from libcst.metadata import ( + FullRepoManager, + FullyQualifiedNameProvider, + PositionProvider, + ScopeProvider, +) +from typer import Argument, Exit, Option, Typer, echo + +from bump_pydantic import __version__ +from bump_pydantic.codemods import gather_codemods +from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor +from bump_pydantic.markers.find_base_model import find_base_model + +app = Typer(help="Convert Pydantic from V1 to V2 ♻️", invoke_without_command=True) + + +def version_callback(value: bool): + if value: + echo(f"bump-pydantic version: {__version__}") + raise Exit() + + +@app.callback() +def main( + package: Path = Argument(..., exists=True, dir_okay=True, allow_dash=False), + diff: bool = Option(False, help="Show diff instead of applying changes."), + version: bool = Option(None, "--version", callback=version_callback, is_eager=True), +): + cwd = os.getcwd() + files_str = [path.absolute() for path in package.glob("**/*.py")] + files = [str(file.relative_to(cwd)) for file in files_str] + + providers = {ScopeProvider, PositionProvider, FullyQualifiedNameProvider} + metadata_manager = FullRepoManager(cwd, files, providers=providers) # type: ignore[arg-type] + metadata_manager.resolve_cache() + + scratch: Dict[str, Any] = {} + for filename in files: + code = Path(filename).read_text() + module = cst.parse_module(code) + module_and_package = calculate_module_and_package(str(package), filename) + + context = CodemodContext( + metadata_manager=metadata_manager, + filename=filename, + full_module_name=module_and_package.name, + full_package_name=module_and_package.package, + scratch=scratch, + ) + visitor = ClassDefVisitor(context=context) + visitor.transform_module(module) + scratch = context.scratch + + find_base_model(scratch["classes"]) + + start_time = time.time() + + codemods = gather_codemods() + + # TODO: We can run this in parallel - batch it into files / cores. + # We may need to run the resolve_cache() on each core - not sure. + for codemod in codemods: + for filename in files: + module_and_package = calculate_module_and_package(str(package), filename) + transformer = codemod( + CodemodContext( + metadata_manager=metadata_manager, + filename=filename, + full_module_name=module_and_package.name, + full_package_name=module_and_package.package, + ) + ) + + old_code = Path(filename).read_text() + input_tree = cst.parse_module(old_code) + output_tree = transformer.transform_module(input_tree) + + input_code = input_tree.code + output_code = output_tree.code + + if input_code != output_code: + if diff: + # TODO: Should be colored. + lines = difflib.unified_diff( + input_code.splitlines(keepends=True), + output_code.splitlines(keepends=True), + fromfile=filename, + tofile=filename, + ) + sys.stdout.writelines(lines) + else: + with open(filename, "w") as fp: + fp.write(output_tree.code) + + modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time] + if modified: + print(f"Refactored {len(modified)} files.") diff --git a/bump_pydantic/commands/__init__.py b/bump_pydantic/markers/__init__.py similarity index 100% rename from bump_pydantic/commands/__init__.py rename to bump_pydantic/markers/__init__.py diff --git a/bump_pydantic/markers/find_base_model.py b/bump_pydantic/markers/find_base_model.py new file mode 100644 index 0000000..2fde4f3 --- /dev/null +++ b/bump_pydantic/markers/find_base_model.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from collections import defaultdict + +from libcst.codemod import CodemodContext + +from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor + +CONTEXT_KEY = "find_base_model" + + +def revert_dictionary(classes: defaultdict[str, set[str]]) -> defaultdict[str, set[str]]: + revert_classes: defaultdict[str, set[str]] = defaultdict(set) + for cls, bases in classes.copy().items(): + for base in bases: + revert_classes[base].add(cls) + return revert_classes + + +def find_base_model(context: CodemodContext) -> None: + classes = context.scratch[ClassDefVisitor.CONTEXT_KEY] + revert_classes = revert_dictionary(classes) + base_model_set: set[str] = set() + + for cls, bases in revert_classes.copy().items(): + if cls in ("pydantic.BaseModel", "BaseModel"): + base_model_set = base_model_set.union(bases) + + visited: set[str] = set() + bases_queue = list(bases) + while bases_queue: + base = bases_queue.pop() + + if base in visited: + continue + visited.add(base) + + base_model_set.add(base) + bases_queue.extend(revert_classes[base]) + + context.scratch[CONTEXT_KEY] = base_model_set + + +if __name__ == "__main__": + import os + import textwrap + from pathlib import Path + from tempfile import TemporaryDirectory + + from libcst.metadata import FullRepoManager, FullyQualifiedNameProvider + from rich.pretty import pprint + + with TemporaryDirectory(dir=os.getcwd()) as tmpdir: + package_dir = f"{tmpdir}/package" + os.mkdir(package_dir) + module_path = f"{package_dir}/a.py" + with open(module_path, "w") as f: + content = textwrap.dedent( + """ + from pydantic import BaseModel + + class Foo(BaseModel): + a: str + + class Bar(Foo): + b: str + + foo = Foo(a="text") + foo.dict() + """ + ) + f.write(content) + module = str(Path(module_path).relative_to(tmpdir)) + mrg = FullRepoManager(tmpdir, {module}, providers={FullyQualifiedNameProvider}) + wrapper = mrg.get_metadata_wrapper_for_path(module) + context = CodemodContext(wrapper=wrapper) + command = ClassDefVisitor(context=context) + mod = wrapper.visit(command) + find_base_model(context=context) + pprint(context.scratch[CONTEXT_KEY]) diff --git a/bump_pydantic/transformers.py b/bump_pydantic/transformers.py deleted file mode 100644 index 8e9fd02..0000000 --- a/bump_pydantic/transformers.py +++ /dev/null @@ -1,105 +0,0 @@ -from __future__ import annotations - -from typing import Callable - -from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand -from libcst.codemod.commands.rename import RenameCommand - -from bump_pydantic.commands.add_default_none import AddDefaultNoneCommand -from bump_pydantic.commands.rename_method_call import RenameMethodCallCommand -from bump_pydantic.commands.replace_call_param import ReplaceCallParam -from bump_pydantic.commands.replace_config_class import ReplaceConfigClassByDict -from bump_pydantic.commands.use_settings import UsePydanticSettingsCommand - -CHANGED_IMPORTS = { - "pydantic.tools": "pydantic.deprecated.tools", - "pydantic.json": "pydantic.deprecated.json", - "pydantic.decorator": "pydantic.deprecated.decorator", - "pydantic.validate_arguments": "pydantic.deprecated.decorator:validate_arguments", - "pydantic.decorator.validate_arguments": "pydantic.deprecated:decorator.validate_arguments", -} - -CHANGED_METHODS = { - "dict": "model_dump", - "json": "model_dump_json", - "parse_obj": "model_validate", - "construct": "model_construct", - "schema": "model_json_schema", - "validate": "model_validate", - "update_forward_refs": "model_rebuild", -} - -CHANGED_CONFIG_PARAMS = { - "allow_population_by_field_name": "populate_by_name", - "anystr_lower": "str_to_lower", - "anystr_strip_whitespace": "str_strip_whitespace", - "anystr_upper": "str_to_upper", - "keep_untouched": "ignored_types", - "max_anystr_length": "str_max_length", - "min_anystr_length": "str_min_length", - "orm_mode": "from_attributes", - "validate_all": "validate_default", -} - - -def gather_transformers( - add_default_none: bool = True, - rename_imports: bool = True, - rename_methods: bool = True, - replace_config_class: bool = True, - replace_config_parameters: bool = True, -) -> list[Callable[[CodemodContext], VisitorBasedCodemodCommand]]: - """Gather all transformers to apply. - - Args: - add_default_none: Whether to add `None` to fields. - rename_imports: Whether to rename imports. - rename_methods: Whether to rename methods. - replace_config_class: Whether to replace `Config` class by `ConfigDict`. - replace_config_parameters: Whether to replace `Config` parameters by `ConfigDict` - parameters. - - Returns: - A list of transformers to apply. - """ - transformers: list[Callable[[CodemodContext], VisitorBasedCodemodCommand]] = [] - - if rename_methods: - transformers.append( - lambda context: RenameMethodCallCommand( - context=context, - class_name="pydantic.main.BaseModel", - methods=CHANGED_METHODS, - ) - ) - - if rename_imports: - # TODO: This can be a single transformer. - transformers.extend( - lambda context: RenameCommand(context, old_import, new_import) - for old_import, new_import in CHANGED_IMPORTS.items() - ) - # NOTE: Including this here, since there's an issue on RenameCommand, and - # UsePydanticSettingsCommand is just a wrapper - which could have been included - # on the list of changed imports above. - transformers.append(UsePydanticSettingsCommand) - - if add_default_none: - transformers.append( - lambda context: AddDefaultNoneCommand( - context=context, class_name="pydantic.main.BaseModel" - ) - ) - - if replace_config_class: - transformers.append(lambda context: ReplaceConfigClassByDict(context=context)) - - if replace_config_parameters: - transformers.append( - lambda context: ReplaceCallParam( - context=context, - callers=("pydantic.config.ConfigDict", "pydantic.ConfigDict"), - params=CHANGED_CONFIG_PARAMS, - ) - ) - return transformers diff --git a/project/rename_method.py b/project/rename_method.py index 77cd392..d969fc9 100644 --- a/project/rename_method.py +++ b/project/rename_method.py @@ -1,4 +1,4 @@ from project.add_none import A -a = A(a=1, b=2, c=3, d=4) +a = A(a=1, b=2, c=3, d=4, e={"ha": "ha"}) a.dict() diff --git a/pyproject.toml b/pyproject.toml index f16bc34..07bea8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "hatchling.build" [project] name = "bump-pydantic" dynamic = ["version"] -description = '' +description = "Convert Pydantic from V1 to V2 ♻" readme = "README.md" requires-python = ">=3.8" license = "MIT" @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: PyPy", "Framework :: Pydantic", ] -dependencies = ["typer>=0.7.0", "rich>=13.3.4", "libcst-mypy"] +dependencies = ["typer>=0.7.0", "libcst"] [project.urls] Documentation = "https://github.com/pydantic/bump-pydantic#readme" @@ -30,92 +30,50 @@ Issues = "https://github.com/pydantic/bump-pydantic/issues" Source = "https://github.com/pydantic/bump-pydantic" [project.scripts] -bump-pydantic = "bump_pydantic.__main__:app" +bump-pydantic = "bump_pydantic.main:app" [tool.hatch.version] path = "bump_pydantic/__init__.py" [tool.hatch.envs.default] -dependencies = ["coverage[toml]>=6.5", "pytest", "rich", "pydantic", "pytest-xdist"] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", + "rich", + "pydantic", + "pytest-xdist", + "rtoml", + "black>=23.1.0", + "mypy>=1.0.0", + "ruff>=0.0.243", +] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" cov-report = ["- coverage combine", "coverage report"] cov = ["test-cov", "cov-report"] +lint = [ + "ruff {args:.}", + "black --check --diff {args:.}", + "mypy {args:bump_pydantic tests}", +] +format = ["black {args:.}", "ruff --fix {args:.}"] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11"] -[tool.hatch.envs.lint] -detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] - -[tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:bump_pydantic tests}" -style = ["ruff {args:.}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] -all = ["style", "typing"] - [tool.black] target-version = ["py38"] skip-string-normalization = true +line-length = 120 [tool.ruff] -target-version = "py38" -line-length = 88 -select = [ - "A", - "ARG", - "B", - "C", - "DTZ", - "E", - "F", - "I", - "ICN", - "ISC", - "PLC", - "PLE", - "PLR", - "PLW", - "RUF", - "S", - "T", - "TID", - "UP", - "W", - "YTT", -] -ignore = [ - # Allow non-abstract empty methods in abstract base classes - "B027", - "B008", - # Ignore checks for possible passwords - "S105", - "S106", - "S107", - # Ignore complexity - "C901", - "PLR0911", - "PLR0912", - "PLR0913", - "PLR0915", -] -unfixable = [ - # Don't touch unused imports - "F401", -] - -[tool.ruff.isort] -known-first-party = ["bump_pydantic"] - -[tool.ruff.flake8-tidy-imports] -ban-relative-imports = "all" - -[tool.ruff.per-file-ignores] -# Tests can use magic values, assertions, and relative imports -"tests/**/*" = ["PLR2004", "S101", "TID252"] +line-length = 120 +extend-select = ['Q', 'RUF100', 'C90', 'UP', 'I'] +mccabe = { max-complexity = 14 } +isort = { known-first-party = ['bump_pydantic', 'tests'] } +target-version = 'py38' [tool.coverage.run] source_pkgs = ["bump_pydantic", "tests"] @@ -127,5 +85,5 @@ skip_covered = true exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [tool.coverage.paths] -bump_pydantic = ["bump_pydantic", "*/bump-pydantic/bump_pydantic"] -tests = ["tests", "*/bump-pydantic/tests"] +source = ["bump_pydantic/"] +detached = true diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index c6fc431..0000000 --- a/tests/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present Marcelo Trylesinski -# -# SPDX-License-Identifier: MIT diff --git a/tests/commands/test_add_default_none.py b/tests/commands/test_add_default_none.py deleted file mode 100644 index 210ea0f..0000000 --- a/tests/commands/test_add_default_none.py +++ /dev/null @@ -1,212 +0,0 @@ -import textwrap -from pathlib import Path - -import libcst as cst -import pytest -from libcst.codemod import CodemodContext -from libcst.metadata import MetadataWrapper -from libcst_mypy import MypyTypeInferenceProvider - -from bump_pydantic.commands.add_default_none import AddDefaultNoneCommand - - -@pytest.mark.parametrize( - "source, output", - [ - pytest.param( - """ - from typing import Optional - - from pydantic import BaseModel - - - class Foo(BaseModel): - bar: Optional[str] - """, - """ - from typing import Optional - - from pydantic import BaseModel - - - class Foo(BaseModel): - bar: Optional[str] = None - """, - id="optional", - ), - pytest.param( - """ - from typing import Dict - - from pydantic import BaseModel - - - class Foo(BaseModel): - bar: Dict[str, str] - """, - """ - from typing import Dict - - from pydantic import BaseModel - - - class Foo(BaseModel): - bar: Dict[str, str] - """, - id="dict", - ), - pytest.param( - """ - import typing - from pydantic import BaseModel - - class Foo(BaseModel): - bar: typing.Optional[str] - """, - """ - import typing - from pydantic import BaseModel - - class Foo(BaseModel): - bar: typing.Optional[str] = None - """, - id="typing.optional", - ), - pytest.param( - """ - from typing import Optional - - from pydantic import BaseModel - - class Foo(BaseModel): - bar: str | None - """, - """ - from typing import Optional - - from pydantic import BaseModel - - class Foo(BaseModel): - bar: str | None = None - """, - id="pipe optional", - ), - pytest.param( - """ - from typing import Optional - - from pydantic import BaseModel - - class Foo(BaseModel): - bar: str | None | int - """, - """ - from typing import Optional - - from pydantic import BaseModel - - class Foo(BaseModel): - bar: str | None | int = None - """, - id="multi pipe optional", - marks=pytest.mark.skip(reason="Not implemented"), - ), - pytest.param( - """ - from typing import Optional - - from pydantic import BaseModel - - class Foo(BaseModel): - bar: Union[str, None] - """, - """ - from typing import Optional - - from pydantic import BaseModel - - class Foo(BaseModel): - bar: Union[str, None] = None - """, - id="union optional", - ), - pytest.param( - """ - from typing import Any - - from pydantic import BaseModel - - class Foo(BaseModel): - bar: Any - """, - """ - from typing import Any - - from pydantic import BaseModel - - class Foo(BaseModel): - bar: Any = None - """, - id="any", - ), - pytest.param( - """ - from typing import Any - - from pydantic import BaseModel - - class Foo(BaseModel): - ... - - class Bar(Foo): - bar: Any - """, - """ - from typing import Any - - from pydantic import BaseModel - - class Foo(BaseModel): - ... - - class Bar(Foo): - bar: Any = None - """, - id="inheritance", - ), - pytest.param( - """ - from typing import Any - - class Foo: - bar: Any - """, - """ - from typing import Any - - class Foo: - bar: Any - """, - id="not pydantic", - ), - ], -) -def test_add_default_none(source: str, output: str, tmp_path: Path) -> None: - package = tmp_path / "package" - package.mkdir() - - source_path = package / "a.py" - source_path.write_text(textwrap.dedent(source)) - - file = str(source_path) - cache = MypyTypeInferenceProvider.gen_cache(package, [file]) - wrapper = MetadataWrapper( - cst.parse_module(source_path.read_text()), - cache={MypyTypeInferenceProvider: cache[file]}, - ) - module = wrapper.visit( - AddDefaultNoneCommand( - context=CodemodContext(), class_name="pydantic.main.BaseModel" - ) - ) - assert module.code == textwrap.dedent(output) diff --git a/tests/commands/test_pydantic_settings.py b/tests/commands/test_pydantic_settings.py deleted file mode 100644 index 859b998..0000000 --- a/tests/commands/test_pydantic_settings.py +++ /dev/null @@ -1,71 +0,0 @@ -import pytest -from libcst.codemod import CodemodTest - -from bump_pydantic.commands.use_settings import UsePydanticSettingsCommand - - -class TestUsePydanticSettingsCommand(CodemodTest): - TRANSFORM = lambda _, context: UsePydanticSettingsCommand(context) - - def test_base_settings(self): - before = """ - from pydantic import BaseSettings - - class Settings(BaseSettings): - foo: str - """ - after = """ - from pydantic_settings import BaseSettings - - class Settings(BaseSettings): - foo: str - """ - self.assertCodemod(before, after) - - @pytest.mark.skip(reason="Not implemented yet") - def test_base_settings_import(self): - before = """ - from pydantic.settings import BaseSettings - - class Settings(BaseSettings): - foo: str - """ - after = """ - from pydantic_settings import BaseSettings - - class Settings(BaseSettings): - foo: str - """ - self.assertCodemod(before, after) - - @pytest.mark.skip(reason="Not implemented yet") - def test_base_settings_import_from(self): - before = """ - import pydantic - - class Settings(pydantic.BaseSettings): - foo: str - """ - after = """ - from pydantic_settings import BaseSettings - - class Settings(BaseSettings): - foo: str - """ - self.assertCodemod(before, after) - - @pytest.mark.skip(reason="Not implemented yet") - def test_base_settings_import_from_alias(self): - before = """ - import pydantic as pd - - class Settings(pd.BaseSettings): - foo: str - """ - after = """ - from pydantic_settings import BaseSettings - - class Settings(BaseSettings): - foo: str - """ - self.assertCodemod(before, after) diff --git a/tests/commands/test_rename_method_call.py b/tests/commands/test_rename_method_call.py deleted file mode 100644 index f6ffa91..0000000 --- a/tests/commands/test_rename_method_call.py +++ /dev/null @@ -1,105 +0,0 @@ -from __future__ import annotations - -import textwrap -from pathlib import Path - -import libcst as cst -import pytest -from libcst.codemod import CodemodContext -from libcst.metadata import MetadataWrapper -from libcst_mypy import MypyTypeInferenceProvider - -from bump_pydantic.commands.rename_method_call import RenameMethodCallCommand - - -@pytest.mark.parametrize( - "source, output", - [ - pytest.param( - """ - from pydantic import BaseModel - - class Foo(BaseModel): - bar: str - - foo = Foo(bar="text") - foo.dict() - """, - """ - from pydantic import BaseModel - - class Foo(BaseModel): - bar: str - - foo = Foo(bar="text") - foo.model_dump() - """, - id="dict", - ), - pytest.param( - """ - class Foo: - bar: str - - foo = Foo(bar="text") - foo.dict() - """, - """ - class Foo: - bar: str - - foo = Foo(bar="text") - foo.dict() - """, - id="dict_no_inheritance", - ), - pytest.param( - """ - from pydantic import BaseModel - - class Foo(BaseModel): - foo: str - - class Bar(Foo): - bar: str - - bar = Bar(foo="text", bar="text") - bar.dict() - """, - """ - from pydantic import BaseModel - - class Foo(BaseModel): - foo: str - - class Bar(Foo): - bar: str - - bar = Bar(foo="text", bar="text") - bar.model_dump() - """, - id="dict_inherited", - ), - ], -) -def test_rename_method_call(source: str, output: str, tmp_path: Path) -> None: - package = tmp_path / "package" - package.mkdir() - - source_path = package / "a.py" - source_path.write_text(textwrap.dedent(source)) - - file = str(source_path) - cache = MypyTypeInferenceProvider.gen_cache(package, [file]) - wrapper = MetadataWrapper( - cst.parse_module(source_path.read_text()), - cache={MypyTypeInferenceProvider: cache[file]}, - ) - module = wrapper.visit( - RenameMethodCallCommand( - context=CodemodContext(wrapper=wrapper), - class_name="pydantic.main.BaseModel", - methods={"dict": "model_dump"}, - ) - ) - assert module.code == textwrap.dedent(output) diff --git a/tests/commands/test_replace_call_param.py b/tests/commands/test_replace_call_param.py deleted file mode 100644 index 018d7e5..0000000 --- a/tests/commands/test_replace_call_param.py +++ /dev/null @@ -1,105 +0,0 @@ -from __future__ import annotations - -import textwrap -from pathlib import Path - -import libcst as cst -import pytest -from libcst.codemod import CodemodContext -from libcst.metadata import MetadataWrapper -from libcst_mypy import MypyTypeInferenceProvider - -from bump_pydantic.commands.replace_call_param import ReplaceCallParam - - -@pytest.mark.parametrize( - "source, output", - [ - pytest.param( - """ - from pydantic import ConfigDict - - ConfigDict(kwarg="potato") - """, - """ - from pydantic import ConfigDict - - ConfigDict(param="potato") - """, - id="simple", - ), - pytest.param( - """ - from pydantic import ConfigDict as ConfigDicto - - ConfigDicto(kwarg="potato") - """, - """ - from pydantic import ConfigDict as ConfigDicto - - ConfigDicto(param="potato") - """, - id="alias", - ), - pytest.param( - """ - from pydantic import config - - config.ConfigDict(kwarg="potato") - """, - """ - from pydantic import config - - config.ConfigDict(param="potato") - """, - id="from", - ), - pytest.param( - """ - from pydantic import config as configo - - configo.ConfigDict(kwarg="potato") - """, - """ - from pydantic import config as configo - - configo.ConfigDict(param="potato") - """, - id="from_alias", - ), - pytest.param( - """ - import pydantic - - pydantic.ConfigDict(kwarg="potato") - """, - """ - import pydantic - - pydantic.ConfigDict(param="potato") - """, - id="import", - ), - ], -) -def test_replace_call_param(source: str, output: str, tmp_path: Path) -> None: - package = tmp_path / "package" - package.mkdir() - - source_path = package / "a.py" - source_path.write_text(textwrap.dedent(source)) - - file = str(source_path) - cache = MypyTypeInferenceProvider.gen_cache(package, [file]) - wrapper = MetadataWrapper( - cst.parse_module(source_path.read_text()), - cache={MypyTypeInferenceProvider: cache[file]}, - ) - module = wrapper.visit( - ReplaceCallParam( - context=CodemodContext(wrapper=wrapper), - callers=("pydantic.config.ConfigDict", "pydantic.ConfigDict"), - params={"kwarg": "param"}, - ) - ) - assert module.code == textwrap.dedent(output) diff --git a/tests/test_add_default_none.py b/tests/test_add_default_none.py new file mode 100644 index 0000000..e79b075 --- /dev/null +++ b/tests/test_add_default_none.py @@ -0,0 +1,153 @@ +import textwrap +from pathlib import Path + +import libcst as cst +import pytest +from libcst import MetadataWrapper, parse_module +from libcst.codemod import CodemodContext, CodemodTest +from libcst.metadata import FullyQualifiedNameProvider +from libcst.testing.utils import UnitTest + +from bump_pydantic.codemods.add_default_none import AddDefaultNoneCommand +from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor +from bump_pydantic.markers.find_base_model import find_base_model + + +class TestClassDefVisitor(UnitTest): + def add_default_none(self, file_path: str, code: str) -> cst.Module: + mod = MetadataWrapper( + parse_module(CodemodTest.make_fixture_data(code)), + cache={ + FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(Path(""), [file_path], None).get( + file_path, "" + ) + }, + ) + mod.resolve_many(AddDefaultNoneCommand.METADATA_DEPENDENCIES) + context = CodemodContext(wrapper=mod) + instance = ClassDefVisitor(context=context) + mod.visit(instance) + + find_base_model(context=context) + + instance = AddDefaultNoneCommand(context=context) # type: ignore[assignment] + return mod.visit(instance) + + def test_no_annotations(self) -> None: + source = textwrap.dedent( + """class Potato: + a: Optional[str] + """ + ) + module = self.add_default_none("some/test/module.py", source) + self.assertEqual(module.code, source) + + def test_with_optional(self) -> None: + module = self.add_default_none( + "some/test/module.py", + """ + from pydantic import BaseModel + + class Potato(BaseModel): + a: Optional[str] + """, + ) + expected = textwrap.dedent( + """from pydantic import BaseModel + +class Potato(BaseModel): + a: Optional[str] = None +""" + ) + self.assertEqual(module.code, expected) + + def test_with_union_none(self) -> None: + module = self.add_default_none( + "some/test/module.py", + """ + from pydantic import BaseModel + from typing import Union + + class Potato(BaseModel): + a: Union[str, None] + """, + ) + expected = textwrap.dedent( + """from pydantic import BaseModel +from typing import Union + +class Potato(BaseModel): + a: Union[str, None] = None +""" + ) + self.assertEqual(module.code, expected) + + def test_with_multiple_classes(self) -> None: + module = self.add_default_none( + "some/test/module.py", + """ + from pydantic import BaseModel + from typing import Optional + + class Potato(BaseModel): + a: Optional[str] + + class Carrot(Potato): + b: Optional[str] + """, + ) + expected = textwrap.dedent( + """from pydantic import BaseModel +from typing import Optional + +class Potato(BaseModel): + a: Optional[str] = None + +class Carrot(Potato): + b: Optional[str] = None + """ + ) + self.assertEqual(module.code, expected) + + def test_any(self) -> None: + module = self.add_default_none( + "some/test/module.py", + """ + from pydantic import BaseModel + from typing import Any + + class Potato(BaseModel): + a: Any + """, + ) + expected = textwrap.dedent( + """from pydantic import BaseModel +from typing import Any + +class Potato(BaseModel): + a: Any = None +""" + ) + self.assertEqual(module.code, expected) + + @pytest.mark.xfail(reason="Recursive Union is not supported") + def test_union_of_union(self) -> None: + module = self.add_default_none( + "some/test/module.py", + """ + from pydantic import BaseModel + from typing import Union + + class Potato(BaseModel): + a: Union[Union[str, None], int] + """, + ) + expected = textwrap.dedent( + """from pydantic import BaseModel +from typing import Union + +class Potato(BaseModel): + a: Union[Union[str, None], int] = None +""" + ) + self.assertEqual(module.code, expected) diff --git a/tests/test_class_def_visitor.py b/tests/test_class_def_visitor.py new file mode 100644 index 0000000..1065780 --- /dev/null +++ b/tests/test_class_def_visitor.py @@ -0,0 +1,84 @@ +from pathlib import Path + +from libcst import MetadataWrapper, parse_module +from libcst.codemod import CodemodContext, CodemodTest +from libcst.metadata import FullyQualifiedNameProvider +from libcst.testing.utils import UnitTest + +from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor + + +class TestClassDefVisitor(UnitTest): + def gather_class_def(self, file_path: str, code: str) -> ClassDefVisitor: + mod = MetadataWrapper( + parse_module(CodemodTest.make_fixture_data(code)), + cache={ + FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(Path(""), [file_path], None).get( + file_path, "" + ) + }, + ) + mod.resolve_many(ClassDefVisitor.METADATA_DEPENDENCIES) + instance = ClassDefVisitor(CodemodContext(wrapper=mod)) + mod.visit(instance) + return instance + + def test_no_annotations(self) -> None: + visitor = self.gather_class_def( + "some/test/module.py", + """ + def foo() -> None: + pass + """, + ) + results = visitor.context.scratch[ClassDefVisitor.CONTEXT_KEY] + self.assertEqual(results, {}) + + def test_without_bases(self) -> None: + visitor = self.gather_class_def( + "some/test/module.py", + """ + class Foo: + pass + """, + ) + results = visitor.context.scratch[ClassDefVisitor.CONTEXT_KEY] + self.assertEqual(results, {}) + + def test_with_class_defs(self) -> None: + visitor = self.gather_class_def( + "some/test/module.py", + """ + from pydantic import BaseModel + + class Foo(BaseModel): + pass + + class Bar(Foo): + pass + """, + ) + results = visitor.context.scratch[ClassDefVisitor.CONTEXT_KEY] + self.assertEqual( + results, + { + "some.test.module.Foo": {"pydantic.BaseModel"}, + "some.test.module.Bar": {"some.test.module.Foo"}, + }, + ) + + def test_with_pydantic_base_model(self) -> None: + visitor = self.gather_class_def( + "some/test/module.py", + """ + import pydantic + + class Foo(pydantic.BaseModel): + ... + """, + ) + results = visitor.context.scratch[ClassDefVisitor.CONTEXT_KEY] + self.assertEqual( + results, + {"some.test.module.Foo": {"pydantic.BaseModel"}}, + ) diff --git a/tests/test_integration.py b/tests/test_integration.py deleted file mode 100644 index 252b8c7..0000000 --- a/tests/test_integration.py +++ /dev/null @@ -1,80 +0,0 @@ -import os -from pathlib import Path - -from dirty_equals import IsAnyStr -from typer.testing import CliRunner - -from bump_pydantic.__main__ import app - - -def test_integration(tmp_path: Path) -> None: - runner = CliRunner() - os.chdir(Path(__file__).parent.parent) - result = runner.invoke( - app, - [ - "project", - "--diff", - "--add-default-none", - "--rename-imports", - "--rename-methods", - "--replace-config-class", - "--replace-config-parameters", - ], - ) - assert result.exception is None - assert result.stdout.splitlines() == [ - "Inferring types... This may take a while.", - "Types are inferred.", - # NOTE: Replace `dict` by `model_dump`. - IsAnyStr(regex=".*/project/rename_method.py"), - IsAnyStr(regex=".*/project/rename_method.py"), - "@@ -1,4 +1,4 @@", - " from project.add_none import A", - " ", - " a = A(a=1, b=2, c=3, d=4)", - "-a.dict()", - "+a.model_dump()", - IsAnyStr(regex=".*/project/settings.py"), - IsAnyStr(regex=".*/project/settings.py"), - "@@ -1,4 +1,4 @@", - "-from pydantic import BaseSettings", - "+from pydantic_settings import BaseSettings", - " ", - " ", - " class Settings(BaseSettings):", - # NOTE: Add `None` to the fields. - IsAnyStr(regex=".*/project/add_none.py"), - IsAnyStr(regex=".*/project/add_none.py"), - "@@ -4,8 +4,8 @@", - " ", - " ", - " class A(BaseModel):", - "- a: int | None", - "- b: Optional[int]", - "- c: Union[int, None]", - "- d: Any", - "+ a: int | None = None", - "+ b: Optional[int] = None", - "+ c: Union[int, None] = None", - "+ d: Any = None", - " e: Dict[str, str]", - "You'll need to manually replace the `Config` class to the `model_config` attribute.", - IsAnyStr(), - # NOTE: Rename `Config` class to `model_config` attribute. - IsAnyStr(regex=".*/project/config_to_model.py"), - IsAnyStr(regex=".*/project/config_to_model.py"), - "@@ -1,10 +1,8 @@", - "-from pydantic import BaseModel", - "+from pydantic import ConfigDict, BaseModel", - " ", - " ", - " class A(BaseModel):", - "- class Config:", - "- orm_mode = True", - "- validate_all = True", - "+ model_config = ConfigDict(orm_mode=True, validate_all=True)", - " ", - " ", - " class BaseConfig:", - ] diff --git a/tests/test_replace_config.py b/tests/test_replace_config.py new file mode 100644 index 0000000..2194789 --- /dev/null +++ b/tests/test_replace_config.py @@ -0,0 +1,42 @@ +from libcst.codemod import CodemodTest + +from bump_pydantic.codemods.replace_config import ReplaceConfigCodemod + + +class TestReplaceConfigCommand(CodemodTest): + TRANSFORM = ReplaceConfigCodemod + + def test_config(self) -> None: + before = """ + from pydantic import BaseModel + + class Potato(BaseModel): + class Config: + allow_arbitrary_types = True + """ + after = """ + from pydantic import ConfigDict, BaseModel + + class Potato(BaseModel): + model_config = ConfigDict(allow_arbitrary_types=True) + """ + self.assertCodemod(before, after) + + def test_noop_config(self) -> None: + code = """ + from pydantic import BaseModel + + class Potato: + class Config: + allow_mutation = True + """ + self.assertCodemod(code, code) + + def test_global_config_class(self) -> None: + code = """ + from pydantic import BaseModel as Potato + + class Config: + allow_arbitrary_types = True + """ + self.assertCodemod(code, code) diff --git a/tests/test_replace_imports.py b/tests/test_replace_imports.py new file mode 100644 index 0000000..1a6e16b --- /dev/null +++ b/tests/test_replace_imports.py @@ -0,0 +1,96 @@ +import pytest +from libcst.codemod import CodemodTest + +from bump_pydantic.codemods.replace_imports import ReplaceImportsCodemod + + +class TestReplaceImportsCommand(CodemodTest): + TRANSFORM = ReplaceImportsCodemod + + def test_base_settings(self) -> None: + before = """ + from pydantic import BaseSettings + """ + after = """ + from pydantic_settings import BaseSettings + """ + self.assertCodemod(before, after) + + def test_noop_base_settings(self) -> None: + code = """ + from potato import BaseSettings + """ + self.assertCodemod(code, code) + + @pytest.mark.xfail(reason="To be implemented.") + def test_base_settings_as(self) -> None: + before = """ + from pydantic import BaseSettings as Potato + """ + after = """ + from pydantic_settings import BaseSettings as Potato + """ + self.assertCodemod(before, after) + + def test_color(self) -> None: + before = """ + from pydantic import Color + """ + after = """ + from pydantic_extra_types.color import Color + """ + self.assertCodemod(before, after) + + def test_color_full(self) -> None: + before = """ + from pydantic.color import Color + """ + after = """ + from pydantic_extra_types.color import Color + """ + self.assertCodemod(before, after) + + def test_noop_color(self) -> None: + code = """ + from potato import Color + """ + self.assertCodemod(code, code) + + def test_payment_card_number(self) -> None: + before = """ + from pydantic import PaymentCardNumber + """ + after = """ + from pydantic_extra_types.payment import PaymentCardNumber + """ + self.assertCodemod(before, after) + + def test_payment_card_brand(self) -> None: + before = """ + from pydantic.payment import PaymentCardBrand + """ + after = """ + from pydantic_extra_types.payment import PaymentCardBrand + """ + self.assertCodemod(before, after) + + def test_noop_payment_card_number(self) -> None: + code = """ + from potato import PaymentCardNumber + """ + self.assertCodemod(code, code) + + def test_noop_payment_card_brand(self) -> None: + code = """ + from potato import PaymentCardBrand + """ + self.assertCodemod(code, code) + + def test_both_payment(self) -> None: + before = """ + from pydantic.payment import PaymentCardNumber, PaymentCardBrand + """ + after = """ + from pydantic_extra_types.payment import PaymentCardBrand, PaymentCardNumber + """ + self.assertCodemod(before, after)