diff --git a/bump_pydantic/codemods/field.py b/bump_pydantic/codemods/field.py index 0f0eda8..54c5439 100644 --- a/bump_pydantic/codemods/field.py +++ b/bump_pydantic/codemods/field.py @@ -1,9 +1,9 @@ -from typing import List +from typing import List, Union import libcst as cst from libcst import matchers as m from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand -from libcst.codemod.visitors import AddImportsVisitor +from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor RENAMED_KEYWORDS = { "min_items": "min_length", @@ -63,11 +63,33 @@ def leave_field_import(self, original_node: cst.Module, updated_node: cst.Module @m.visit(m.AnnAssign(value=m.Call(func=m.Name("Field")))) def visit_field_assign(self, node: cst.AnnAssign) -> None: self.inside_field_assign = True + self._const: Union[cst.Arg, None] = None @m.leave(m.AnnAssign(value=m.Call(func=m.Name("Field")))) def leave_field_assign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign: self.inside_field_assign = False - return updated_node + + if self._const is None: + return updated_node + + AddImportsVisitor.add_needed_import(self.context, "typing", "Literal") + RemoveImportsVisitor.remove_unused_import(self.context, "pydantic", "Field") + return updated_node.with_changes( + annotation=cst.Annotation( + annotation=cst.Subscript( + value=cst.Name("Literal"), + slice=[cst.SubscriptElement(slice=cst.Index(value=self._const.value))], + ) + ), + value=self._const.value, + ) + + @m.visit(m.Call(func=m.Name("Field"))) + def visit_field_call(self, node: cst.Call) -> None: + # Check if there's a `const=True` argument. + const_arg = m.Arg(value=m.Name("True"), keyword=m.Name("const")) + if m.matches(node, m.Call(func=m.Name("Field"), args=[~m.Arg(value=m.Name("...")), const_arg])): + self._const = node.args[0] @m.leave(m.Call(func=m.Name("Field"))) def leave_field_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: diff --git a/bump_pydantic/codemods/replace_generic_model.py b/bump_pydantic/codemods/replace_generic_model.py index 9607bca..5b741fa 100644 --- a/bump_pydantic/codemods/replace_generic_model.py +++ b/bump_pydantic/codemods/replace_generic_model.py @@ -18,7 +18,7 @@ def leave_generic_model(self, original_node: cst.ClassDef, updated_node: cst.Cla AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="BaseModel") return updated_node.with_changes( bases=[ - base if not m.matches(base, GENERIC_MODEL_ARG) else cst.Arg(value=cst.Name("BaseModel")) + cst.Arg(value=cst.Name("BaseModel")) if m.matches(base, GENERIC_MODEL_ARG) else base for base in updated_node.bases ] ) diff --git a/tests/unit/test_field.py b/tests/unit/test_field.py index 50aca70..be7a8ea 100644 --- a/tests/unit/test_field.py +++ b/tests/unit/test_field.py @@ -74,3 +74,31 @@ class Potato(BaseModel): potato: int = Field(..., env="POTATO") """ self.assertCodemod(code, code) + + def test_replace_const_by_literal_type(self) -> None: + before = """ + from enum import Enum + + from pydantic import BaseModel, Field + + + class MyEnum(Enum): + POTATO = "potato" + + class Potato(BaseModel): + potato: MyEnum = Field(MyEnum.POTATO, const=True) + """ + after = """ + from enum import Enum + + from pydantic import BaseModel + from typing import Literal + + + class MyEnum(Enum): + POTATO = "potato" + + class Potato(BaseModel): + potato: Literal[MyEnum.POTATO] = MyEnum.POTATO + """ + self.assertCodemod(before, after)