diff --git a/bump_pydantic/codemods/field.py b/bump_pydantic/codemods/field.py index 01a2f50..4ddf62a 100644 --- a/bump_pydantic/codemods/field.py +++ b/bump_pydantic/codemods/field.py @@ -12,13 +12,44 @@ "regex": "pattern", } +IMPORT_FIELD = m.Module( + body=[ + m.ZeroOrMore(), + m.SimpleStatementLine( + body=[ + m.ZeroOrMore(), + m.ImportFrom( + module=m.Name("pydantic"), + names=[ + m.ZeroOrMore(), + m.ImportAlias(name=m.Name("Field")), + m.ZeroOrMore(), + ], + ), + m.ZeroOrMore(), + ], + ), + m.ZeroOrMore(), + ] +) + class FieldCodemod(VisitorBasedCodemodCommand): def __init__(self, context: CodemodContext) -> None: super().__init__(context) + self.has_field_import = False self.inside_field_assign = False + @m.visit(IMPORT_FIELD) + def visit_field_import(self, node: cst.Module) -> None: + self.has_field_import = True + + @m.leave(IMPORT_FIELD) + def leave_field_import(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: + self.has_field_import = False + return updated_node + @m.visit(m.AnnAssign(value=m.Call(func=m.Name("Field")))) def visit_field_assign(self, node: cst.AnnAssign) -> None: self.inside_field_assign = True @@ -30,7 +61,7 @@ def leave_field_assign(self, original_node: cst.AnnAssign, updated_node: cst.Ann @m.leave(m.Call(func=m.Name("Field"))) def leave_field_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: - if not self.inside_field_assign: + if not self.has_field_import or not self.inside_field_assign: return updated_node new_args: List[cst.Arg] = [] diff --git a/tests/unit/test_field.py b/tests/unit/test_field.py index dd123fb..c420b7f 100644 --- a/tests/unit/test_field.py +++ b/tests/unit/test_field.py @@ -1,3 +1,4 @@ +import pytest from libcst.codemod import CodemodTest from bump_pydantic.codemods.field import FieldCodemod @@ -10,15 +11,41 @@ class TestReplaceConfigCommand(CodemodTest): def test_field_rename(self) -> None: before = """ - from pydantic import BaseModel + from pydantic import BaseModel, Field class Potato(BaseModel): potato: List[int] = Field(..., min_items=1, max_items=10) """ after = """ - from pydantic import BaseModel + from pydantic import BaseModel, Field class Potato(BaseModel): potato: List[int] = Field(..., min_length=1, max_length=10) """ self.assertCodemod(before, after) + + def test_noop(self) -> None: + code = """ + from pydantic import BaseModel + from potato import Field + + class Potato(BaseModel): + potato: List[int] = Field(..., max_items=1) + """ + self.assertCodemod(code, code) + + @pytest.mark.xfail(reason="Not implemented yet") + def test_field_rename_with_pydantic_import(self) -> None: + before = """ + import pydantic + + class Potato(pydantic.BaseModel): + potato: List[int] = pydantic.Field(..., min_items=1, max_items=10) + """ + after = """ + from pydantic import BaseModel, Field + + class Potato(pydantic.BaseModel): + potato: List[int] = pydantic.Field(..., min_length=1, max_length=10) + """ + self.assertCodemod(before, after)