diff --git a/bump_pydantic/commands/use_settings.py b/bump_pydantic/commands/use_settings.py new file mode 100644 index 0000000..08e6d2a --- /dev/null +++ b/bump_pydantic/commands/use_settings.py @@ -0,0 +1,24 @@ +from libcst.codemod import CodemodContext +from libcst.codemod.commands.rename import RenameCommand + + +def UsePydanticSettingsCommand(context: CodemodContext): + """Support for pydantic.BaseSettings. + + This command will rename pydantic.BaseSettings to pydantic_settings:BaseSettings. + + It doesn't support the following cases: + - from pydantic.settings import BaseSettings + - import pydantic ... class Settings(pydantic.BaseSettings) + - import pydantic as pd ... class Settings(pd.BaseSettings) + + TODO: Support the above cases. To implement the above, you'll need to go to each + `ClassDef`, and see the bases. If there's a `pydantic.settings.BaseSettings` in the + bases, then you'll need to use `RemoveImportsVisitor` and `AddImportsVisitor` from + `libcst.codemod.visitors`. + """ + return RenameCommand( + context=context, + old_name="pydantic.BaseSettings", + new_name="pydantic_settings:BaseSettings", + ) diff --git a/bump_pydantic/transformers.py b/bump_pydantic/transformers.py index 791d973..8e9fd02 100644 --- a/bump_pydantic/transformers.py +++ b/bump_pydantic/transformers.py @@ -9,6 +9,7 @@ from bump_pydantic.commands.rename_method_call import RenameMethodCallCommand from bump_pydantic.commands.replace_call_param import ReplaceCallParam from bump_pydantic.commands.replace_config_class import ReplaceConfigClassByDict +from bump_pydantic.commands.use_settings import UsePydanticSettingsCommand CHANGED_IMPORTS = { "pydantic.tools": "pydantic.deprecated.tools", @@ -78,6 +79,10 @@ def gather_transformers( lambda context: RenameCommand(context, old_import, new_import) for old_import, new_import in CHANGED_IMPORTS.items() ) + # NOTE: Including this here, since there's an issue on RenameCommand, and + # UsePydanticSettingsCommand is just a wrapper - which could have been included + # on the list of changed imports above. + transformers.append(UsePydanticSettingsCommand) if add_default_none: transformers.append( diff --git a/project/settings.py b/project/settings.py new file mode 100644 index 0000000..d6064ec --- /dev/null +++ b/project/settings.py @@ -0,0 +1,5 @@ +from pydantic import BaseSettings + + +class Settings(BaseSettings): + a: int diff --git a/tests/commands/test_add_default_none.py b/tests/commands/test_add_default_none.py index 6cf8e31..210ea0f 100644 --- a/tests/commands/test_add_default_none.py +++ b/tests/commands/test_add_default_none.py @@ -1,6 +1,4 @@ import textwrap -from libcst.metadata import MetadataWrapper -from libcst_mypy import MypyTypeInferenceProvider from pathlib import Path import libcst as cst diff --git a/tests/commands/test_pydantic_settings.py b/tests/commands/test_pydantic_settings.py new file mode 100644 index 0000000..859b998 --- /dev/null +++ b/tests/commands/test_pydantic_settings.py @@ -0,0 +1,71 @@ +import pytest +from libcst.codemod import CodemodTest + +from bump_pydantic.commands.use_settings import UsePydanticSettingsCommand + + +class TestUsePydanticSettingsCommand(CodemodTest): + TRANSFORM = lambda _, context: UsePydanticSettingsCommand(context) + + def test_base_settings(self): + before = """ + from pydantic import BaseSettings + + class Settings(BaseSettings): + foo: str + """ + after = """ + from pydantic_settings import BaseSettings + + class Settings(BaseSettings): + foo: str + """ + self.assertCodemod(before, after) + + @pytest.mark.skip(reason="Not implemented yet") + def test_base_settings_import(self): + before = """ + from pydantic.settings import BaseSettings + + class Settings(BaseSettings): + foo: str + """ + after = """ + from pydantic_settings import BaseSettings + + class Settings(BaseSettings): + foo: str + """ + self.assertCodemod(before, after) + + @pytest.mark.skip(reason="Not implemented yet") + def test_base_settings_import_from(self): + before = """ + import pydantic + + class Settings(pydantic.BaseSettings): + foo: str + """ + after = """ + from pydantic_settings import BaseSettings + + class Settings(BaseSettings): + foo: str + """ + self.assertCodemod(before, after) + + @pytest.mark.skip(reason="Not implemented yet") + def test_base_settings_import_from_alias(self): + before = """ + import pydantic as pd + + class Settings(pd.BaseSettings): + foo: str + """ + after = """ + from pydantic_settings import BaseSettings + + class Settings(BaseSettings): + foo: str + """ + self.assertCodemod(before, after) diff --git a/tests/test_integration.py b/tests/test_integration.py index 612834f..252b8c7 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -35,6 +35,14 @@ def test_integration(tmp_path: Path) -> None: " a = A(a=1, b=2, c=3, d=4)", "-a.dict()", "+a.model_dump()", + IsAnyStr(regex=".*/project/settings.py"), + IsAnyStr(regex=".*/project/settings.py"), + "@@ -1,4 +1,4 @@", + "-from pydantic import BaseSettings", + "+from pydantic_settings import BaseSettings", + " ", + " ", + " class Settings(BaseSettings):", # NOTE: Add `None` to the fields. IsAnyStr(regex=".*/project/add_none.py"), IsAnyStr(regex=".*/project/add_none.py"),