From 368b2969cf3da9d660d3771c17664cf9011af373 Mon Sep 17 00:00:00 2001 From: Denis Artyushin Date: Sun, 18 Feb 2024 17:52:41 +0300 Subject: [PATCH] Add json_schema_extra check --- bump_pydantic/codemods/field.py | 29 +++++++++++++++++++++++++++++ tests/unit/test_field.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/bump_pydantic/codemods/field.py b/bump_pydantic/codemods/field.py index 5ba7f47..da694de 100644 --- a/bump_pydantic/codemods/field.py +++ b/bump_pydantic/codemods/field.py @@ -4,6 +4,7 @@ from libcst import matchers as m from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor +from pydantic import Field RENAMED_KEYWORDS = { "min_items": "min_length", @@ -111,9 +112,26 @@ def leave_field_call(self, original_node: cst.Call, updated_node: cst.Call) -> c if not self.has_field_import or not self.inside_field_assign: return updated_node + json_schema_extra_elements: List[cst.DictElement] = [] new_args: List[cst.Arg] = [] for arg in updated_node.args: if m.matches(arg, m.Arg(keyword=m.Name())): + if arg.keyword.value == "json_schema_extra": + json_schema_extra_elements.extend(arg.value.elements) # type: ignore + continue + + if ( + (arg.keyword.value not in RENAMED_KEYWORDS) + and (arg.keyword.value not in Field.__annotations__) + and (arg.keyword != "extra") + ): + new_dict_element = cst.DictElement( + key=cst.SimpleString(value=f'"{arg.keyword.value}"'), value=arg.value, + + ) + json_schema_extra_elements.append(new_dict_element) + continue + keyword = RENAMED_KEYWORDS.get(arg.keyword.value, arg.keyword.value) # type: ignore value = arg.value if arg.keyword: @@ -131,6 +149,17 @@ def leave_field_call(self, original_node: cst.Call, updated_node: cst.Call) -> c else: new_args.append(arg) + if len(json_schema_extra_elements) > 0: + extra_arg = cst.Arg( + value=cst.Dict(elements=json_schema_extra_elements), + keyword=cst.Name(value="json_schema_extra"), + equal=cst.AssignEqual( + whitespace_before=cst.SimpleWhitespace(""), whitespace_after=cst.SimpleWhitespace("") + ), + ) + + new_args.append(extra_arg) + return updated_node.with_changes(args=new_args) diff --git a/tests/unit/test_field.py b/tests/unit/test_field.py index 448aa01..3fad806 100644 --- a/tests/unit/test_field.py +++ b/tests/unit/test_field.py @@ -151,3 +151,33 @@ class Settings(BaseSettings): potato: int = Field(..., examples=[1]) """ self.assertCodemod(before, after) + + def test_json_schema_extra_exist(self) -> None: + before = """ + from pydantic import BaseModel, Field + + class Human(BaseModel): + name: str = Field(..., some_extra_field="some_extra_field_value", json_schema_extra={"a": "b"}) + """ + after = """ + from pydantic import BaseModel, Field + + class Human(BaseModel): + name: str = Field(..., json_schema_extra={"some_extra_field": "some_extra_field_value", "a": "b"}) + """ + self.assertCodemod(before, after) + + def test_json_schema_extra_not_exist(self) -> None: + before = """ + from pydantic import BaseModel, Field + + class Human(BaseModel): + name: str = Field(..., min_length=1, some_extra_field="some_extra_field_value") + """ + after = """ + from pydantic import BaseModel, Field + + class Human(BaseModel): + name: str = Field(..., min_length=1, json_schema_extra={"some_extra_field": "some_extra_field_value"}) + """ + self.assertCodemod(before, after)