Skip to content

Commit

Permalink
Fix replace() use in AgentSpec (#326)
Browse files Browse the repository at this point in the history
* Fix replace use in AgentSpec

* Filter out  from replace

* Inform perform_self_test has no effect

* Update smarts/core/agent.py

Co-authored-by: adai <[email protected]>

* Make format

Co-authored-by: adai <[email protected]>
  • Loading branch information
Gamenot and Adaickalavan authored Dec 22, 2020
1 parent 5e46a1d commit 0ad2ace
Showing 1 changed file with 32 additions and 3 deletions.
35 changes: 32 additions & 3 deletions smarts/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,15 @@ def __post_init__(self):
if self.policy_builder:
logger.warning(
f"[DEPRECATED] Please use AgentSpec(agent_builder=<...>) instead of AgentSpec(policy_builder=<..>):\n {self}"
"policy_builder will overwrite agent_builder"
)
assert self.agent_builder is None, self.agent_builder
self.agent_builder = self.policy_builder

if self.policy_params:
logger.warning(
f"[DEPRECATED] Please use AgentSpec(agent_params=<...>) instead of AgentSpec(policy_params=<..>):\n {self}"
"policy_builder will overwrite agent_builder"
)
assert self.agent_params is None, self.agent_params
self.agent_params = self.policy_params

self.policy_params = self.agent_params
Expand All @@ -146,7 +146,36 @@ def __post_init__(self):
def replace(self, **kwargs) -> "AgentSpec":
"""Return a copy of this AgentSpec with the given fields updated."""

return replace(self, **kwargs)
replacements = [
("policy_builder", "agent_builder"),
("policy_params", "agent_params"),
("perform_self_test", None),
]

assert (
None not in kwargs
), f"Error: kwargs input to replace() function contains invalid key `None`: {kwargs}"

kwargs_copy = kwargs.copy()
for deprecated, current in replacements:
if deprecated in kwargs:
if current:
logger.warning(
f"[DEPRECATED] Please use AgentSpec.replace({current}=<...>) instead of AgentSpec.replace({deprecated}=<...>)\n"
)
else:
logger.warning(
f"[DEPRECATED] Attribute {deprecated} no longer has effect."
)
assert (
current not in kwargs
), f"Mixed current ({current}) and deprecated ({deprecated}) values in replace"
moved = kwargs[deprecated]
del kwargs_copy[deprecated]
if current:
kwargs_copy[current] = moved

return replace(self, **kwargs_copy)

def build_agent(self) -> Agent:
"""Construct an Agent from the AgentSpec configuration."""
Expand Down

0 comments on commit 0ad2ace

Please sign in to comment.