Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Support const=True to Literal[T] #41

Merged
merged 1 commit into from
Jun 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions bump_pydantic/codemods/field.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import List
from typing import List, Union

import libcst as cst
from libcst import matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor

RENAMED_KEYWORDS = {
"min_items": "min_length",
Expand Down Expand Up @@ -63,11 +63,33 @@ def leave_field_import(self, original_node: cst.Module, updated_node: cst.Module
@m.visit(m.AnnAssign(value=m.Call(func=m.Name("Field"))))
def visit_field_assign(self, node: cst.AnnAssign) -> None:
self.inside_field_assign = True
self._const: Union[cst.Arg, None] = None

@m.leave(m.AnnAssign(value=m.Call(func=m.Name("Field"))))
def leave_field_assign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign:
self.inside_field_assign = False
return updated_node

if self._const is None:
return updated_node

AddImportsVisitor.add_needed_import(self.context, "typing", "Literal")
RemoveImportsVisitor.remove_unused_import(self.context, "pydantic", "Field")
return updated_node.with_changes(
annotation=cst.Annotation(
annotation=cst.Subscript(
value=cst.Name("Literal"),
slice=[cst.SubscriptElement(slice=cst.Index(value=self._const.value))],
)
),
value=self._const.value,
)

@m.visit(m.Call(func=m.Name("Field")))
def visit_field_call(self, node: cst.Call) -> None:
# Check if there's a `const=True` argument.
const_arg = m.Arg(value=m.Name("True"), keyword=m.Name("const"))
if m.matches(node, m.Call(func=m.Name("Field"), args=[~m.Arg(value=m.Name("...")), const_arg])):
self._const = node.args[0]

@m.leave(m.Call(func=m.Name("Field")))
def leave_field_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
Expand Down
2 changes: 1 addition & 1 deletion bump_pydantic/codemods/replace_generic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def leave_generic_model(self, original_node: cst.ClassDef, updated_node: cst.Cla
AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="BaseModel")
return updated_node.with_changes(
bases=[
base if not m.matches(base, GENERIC_MODEL_ARG) else cst.Arg(value=cst.Name("BaseModel"))
cst.Arg(value=cst.Name("BaseModel")) if m.matches(base, GENERIC_MODEL_ARG) else base
for base in updated_node.bases
]
)
Expand Down
28 changes: 28 additions & 0 deletions tests/unit/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,31 @@ class Potato(BaseModel):
potato: int = Field(..., env="POTATO")
"""
self.assertCodemod(code, code)

def test_replace_const_by_literal_type(self) -> None:
before = """
from enum import Enum

from pydantic import BaseModel, Field


class MyEnum(Enum):
POTATO = "potato"

class Potato(BaseModel):
potato: MyEnum = Field(MyEnum.POTATO, const=True)
"""
after = """
from enum import Enum

from pydantic import BaseModel
from typing import Literal


class MyEnum(Enum):
POTATO = "potato"

class Potato(BaseModel):
potato: Literal[MyEnum.POTATO] = MyEnum.POTATO
"""
self.assertCodemod(before, after)