Skip to content

Commit

Permalink
✨ Support const=True to Literal[T] (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Jun 29, 2023
1 parent a4a7c7d commit aeb5dcc
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 4 deletions.
28 changes: 25 additions & 3 deletions bump_pydantic/codemods/field.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import List
from typing import List, Union

import libcst as cst
from libcst import matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor

RENAMED_KEYWORDS = {
"min_items": "min_length",
Expand Down Expand Up @@ -63,11 +63,33 @@ def leave_field_import(self, original_node: cst.Module, updated_node: cst.Module
@m.visit(m.AnnAssign(value=m.Call(func=m.Name("Field"))))
def visit_field_assign(self, node: cst.AnnAssign) -> None:
self.inside_field_assign = True
self._const: Union[cst.Arg, None] = None

@m.leave(m.AnnAssign(value=m.Call(func=m.Name("Field"))))
def leave_field_assign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign:
self.inside_field_assign = False
return updated_node

if self._const is None:
return updated_node

AddImportsVisitor.add_needed_import(self.context, "typing", "Literal")
RemoveImportsVisitor.remove_unused_import(self.context, "pydantic", "Field")
return updated_node.with_changes(
annotation=cst.Annotation(
annotation=cst.Subscript(
value=cst.Name("Literal"),
slice=[cst.SubscriptElement(slice=cst.Index(value=self._const.value))],
)
),
value=self._const.value,
)

@m.visit(m.Call(func=m.Name("Field")))
def visit_field_call(self, node: cst.Call) -> None:
# Check if there's a `const=True` argument.
const_arg = m.Arg(value=m.Name("True"), keyword=m.Name("const"))
if m.matches(node, m.Call(func=m.Name("Field"), args=[~m.Arg(value=m.Name("...")), const_arg])):
self._const = node.args[0]

@m.leave(m.Call(func=m.Name("Field")))
def leave_field_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
Expand Down
2 changes: 1 addition & 1 deletion bump_pydantic/codemods/replace_generic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def leave_generic_model(self, original_node: cst.ClassDef, updated_node: cst.Cla
AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="BaseModel")
return updated_node.with_changes(
bases=[
base if not m.matches(base, GENERIC_MODEL_ARG) else cst.Arg(value=cst.Name("BaseModel"))
cst.Arg(value=cst.Name("BaseModel")) if m.matches(base, GENERIC_MODEL_ARG) else base
for base in updated_node.bases
]
)
Expand Down
28 changes: 28 additions & 0 deletions tests/unit/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,31 @@ class Potato(BaseModel):
potato: int = Field(..., env="POTATO")
"""
self.assertCodemod(code, code)

def test_replace_const_by_literal_type(self) -> None:
before = """
from enum import Enum
from pydantic import BaseModel, Field
class MyEnum(Enum):
POTATO = "potato"
class Potato(BaseModel):
potato: MyEnum = Field(MyEnum.POTATO, const=True)
"""
after = """
from enum import Enum
from pydantic import BaseModel
from typing import Literal
class MyEnum(Enum):
POTATO = "potato"
class Potato(BaseModel):
potato: Literal[MyEnum.POTATO] = MyEnum.POTATO
"""
self.assertCodemod(before, after)

0 comments on commit aeb5dcc

Please sign in to comment.