From ef97ee0308c25bcb42b5d59f0ee70feb2ed3ae77 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 19 Jun 2023 17:27:14 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Use=20only=20attributes=20from?= =?UTF-8?q?=20the=20Config=20class?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bump_pydantic/codemods/replace_config.py | 19 +++++++++-------- tests/test_replace_config.py | 27 ++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/bump_pydantic/codemods/replace_config.py b/bump_pydantic/codemods/replace_config.py index 98f71bc..7c3e14f 100644 --- a/bump_pydantic/codemods/replace_config.py +++ b/bump_pydantic/codemods/replace_config.py @@ -90,16 +90,17 @@ def visit_Assign(self, node: cst.Assign) -> None: self.assign_value = node.value def visit_AssignTarget(self, node: cst.AssignTarget) -> None: - self.config_args.append( - cst.Arg( - keyword=node.target, # type: ignore[arg-type] - value=self.assign_value, - equal=cst.AssignEqual( - whitespace_before=cst.SimpleWhitespace(""), - whitespace_after=cst.SimpleWhitespace(""), - ), + if self.inside_config_class: + self.config_args.append( + cst.Arg( + keyword=node.target, # type: ignore[arg-type] + value=self.assign_value, + equal=cst.AssignEqual( + whitespace_before=cst.SimpleWhitespace(""), + whitespace_after=cst.SimpleWhitespace(""), + ), + ) ) - ) def leave_Module(self, original_node: Module, updated_node: Module) -> Module: return updated_node diff --git a/tests/test_replace_config.py b/tests/test_replace_config.py index 2194789..3f1addb 100644 --- a/tests/test_replace_config.py +++ b/tests/test_replace_config.py @@ -40,3 +40,30 @@ class Config: allow_arbitrary_types = True """ self.assertCodemod(code, code) + + def test_reset_config_args(self) -> None: + before = """ + from pydantic import BaseModel + + class Potato(BaseModel): + class Config: + allow_arbitrary_types = True + + potato = Potato() + + class Potato2(BaseModel): + class Config: + allow_mutation = True + """ + after = """ + from pydantic import ConfigDict, BaseModel + + class Potato(BaseModel): + model_config = ConfigDict(allow_arbitrary_types=True) + + potato = Potato() + + class Potato2(BaseModel): + model_config = ConfigDict(allow_mutation=True) + """ + self.assertCodemod(before, after)