Skip to content

Commit

Permalink
Add json_schema_extra check
Browse files Browse the repository at this point in the history
  • Loading branch information
denisart committed Feb 18, 2024
1 parent e97380c commit 368b296
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
29 changes: 29 additions & 0 deletions bump_pydantic/codemods/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from libcst import matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
from pydantic import Field

RENAMED_KEYWORDS = {
"min_items": "min_length",
Expand Down Expand Up @@ -111,9 +112,26 @@ def leave_field_call(self, original_node: cst.Call, updated_node: cst.Call) -> c
if not self.has_field_import or not self.inside_field_assign:
return updated_node

json_schema_extra_elements: List[cst.DictElement] = []
new_args: List[cst.Arg] = []
for arg in updated_node.args:
if m.matches(arg, m.Arg(keyword=m.Name())):
if arg.keyword.value == "json_schema_extra":
json_schema_extra_elements.extend(arg.value.elements) # type: ignore
continue

if (
(arg.keyword.value not in RENAMED_KEYWORDS)
and (arg.keyword.value not in Field.__annotations__)
and (arg.keyword != "extra")
):
new_dict_element = cst.DictElement(
key=cst.SimpleString(value=f'"{arg.keyword.value}"'), value=arg.value,

)
json_schema_extra_elements.append(new_dict_element)
continue

keyword = RENAMED_KEYWORDS.get(arg.keyword.value, arg.keyword.value) # type: ignore
value = arg.value
if arg.keyword:
Expand All @@ -131,6 +149,17 @@ def leave_field_call(self, original_node: cst.Call, updated_node: cst.Call) -> c
else:
new_args.append(arg)

if len(json_schema_extra_elements) > 0:
extra_arg = cst.Arg(
value=cst.Dict(elements=json_schema_extra_elements),
keyword=cst.Name(value="json_schema_extra"),
equal=cst.AssignEqual(
whitespace_before=cst.SimpleWhitespace(""), whitespace_after=cst.SimpleWhitespace("")
),
)

new_args.append(extra_arg)

return updated_node.with_changes(args=new_args)


Expand Down
30 changes: 30 additions & 0 deletions tests/unit/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,33 @@ class Settings(BaseSettings):
potato: int = Field(..., examples=[1])
"""
self.assertCodemod(before, after)

def test_json_schema_extra_exist(self) -> None:
before = """
from pydantic import BaseModel, Field
class Human(BaseModel):
name: str = Field(..., some_extra_field="some_extra_field_value", json_schema_extra={"a": "b"})
"""
after = """
from pydantic import BaseModel, Field
class Human(BaseModel):
name: str = Field(..., json_schema_extra={"some_extra_field": "some_extra_field_value", "a": "b"})
"""
self.assertCodemod(before, after)

def test_json_schema_extra_not_exist(self) -> None:
before = """
from pydantic import BaseModel, Field
class Human(BaseModel):
name: str = Field(..., min_length=1, some_extra_field="some_extra_field_value")
"""
after = """
from pydantic import BaseModel, Field
class Human(BaseModel):
name: str = Field(..., min_length=1, json_schema_extra={"some_extra_field": "some_extra_field_value"})
"""
self.assertCodemod(before, after)

0 comments on commit 368b296

Please sign in to comment.