From 0ad2aceedf14d270703d3534837bbe10010cdee7 Mon Sep 17 00:00:00 2001 From: Tucker Date: Tue, 22 Dec 2020 13:14:37 -0500 Subject: [PATCH] Fix `replace()` use in AgentSpec (#326) * Fix replace use in AgentSpec * Filter out from replace * Inform perform_self_test has no effect * Update smarts/core/agent.py Co-authored-by: adai * Make format Co-authored-by: adai --- smarts/core/agent.py | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/smarts/core/agent.py b/smarts/core/agent.py index a7bb7ea57a..04a3e7ef26 100644 --- a/smarts/core/agent.py +++ b/smarts/core/agent.py @@ -129,15 +129,15 @@ def __post_init__(self): if self.policy_builder: logger.warning( f"[DEPRECATED] Please use AgentSpec(agent_builder=<...>) instead of AgentSpec(policy_builder=<..>):\n {self}" + "policy_builder will overwrite agent_builder" ) - assert self.agent_builder is None, self.agent_builder self.agent_builder = self.policy_builder if self.policy_params: logger.warning( f"[DEPRECATED] Please use AgentSpec(agent_params=<...>) instead of AgentSpec(policy_params=<..>):\n {self}" + "policy_builder will overwrite agent_builder" ) - assert self.agent_params is None, self.agent_params self.agent_params = self.policy_params self.policy_params = self.agent_params @@ -146,7 +146,36 @@ def __post_init__(self): def replace(self, **kwargs) -> "AgentSpec": """Return a copy of this AgentSpec with the given fields updated.""" - return replace(self, **kwargs) + replacements = [ + ("policy_builder", "agent_builder"), + ("policy_params", "agent_params"), + ("perform_self_test", None), + ] + + assert ( + None not in kwargs + ), f"Error: kwargs input to replace() function contains invalid key `None`: {kwargs}" + + kwargs_copy = kwargs.copy() + for deprecated, current in replacements: + if deprecated in kwargs: + if current: + logger.warning( + f"[DEPRECATED] Please use AgentSpec.replace({current}=<...>) instead of AgentSpec.replace({deprecated}=<...>)\n" + ) + else: + logger.warning( + f"[DEPRECATED] Attribute {deprecated} no longer has effect." + ) + assert ( + current not in kwargs + ), f"Mixed current ({current}) and deprecated ({deprecated}) values in replace" + moved = kwargs[deprecated] + del kwargs_copy[deprecated] + if current: + kwargs_copy[current] = moved + + return replace(self, **kwargs_copy) def build_agent(self) -> Agent: """Construct an Agent from the AgentSpec configuration."""