diff --git a/bump_pydantic/codemods/replace_config.py b/bump_pydantic/codemods/replace_config.py index 98f71bc..7c3e14f 100644 --- a/bump_pydantic/codemods/replace_config.py +++ b/bump_pydantic/codemods/replace_config.py @@ -90,16 +90,17 @@ def visit_Assign(self, node: cst.Assign) -> None: self.assign_value = node.value def visit_AssignTarget(self, node: cst.AssignTarget) -> None: - self.config_args.append( - cst.Arg( - keyword=node.target, # type: ignore[arg-type] - value=self.assign_value, - equal=cst.AssignEqual( - whitespace_before=cst.SimpleWhitespace(""), - whitespace_after=cst.SimpleWhitespace(""), - ), + if self.inside_config_class: + self.config_args.append( + cst.Arg( + keyword=node.target, # type: ignore[arg-type] + value=self.assign_value, + equal=cst.AssignEqual( + whitespace_before=cst.SimpleWhitespace(""), + whitespace_after=cst.SimpleWhitespace(""), + ), + ) ) - ) def leave_Module(self, original_node: Module, updated_node: Module) -> Module: return updated_node diff --git a/tests/test_replace_config.py b/tests/test_replace_config.py index 2194789..3f1addb 100644 --- a/tests/test_replace_config.py +++ b/tests/test_replace_config.py @@ -40,3 +40,30 @@ class Config: allow_arbitrary_types = True """ self.assertCodemod(code, code) + + def test_reset_config_args(self) -> None: + before = """ + from pydantic import BaseModel + + class Potato(BaseModel): + class Config: + allow_arbitrary_types = True + + potato = Potato() + + class Potato2(BaseModel): + class Config: + allow_mutation = True + """ + after = """ + from pydantic import ConfigDict, BaseModel + + class Potato(BaseModel): + model_config = ConfigDict(allow_arbitrary_types=True) + + potato = Potato() + + class Potato2(BaseModel): + model_config = ConfigDict(allow_mutation=True) + """ + self.assertCodemod(before, after)