diff --git a/pettingzoo/classic/rlcard_envs/rlcard_base.py b/pettingzoo/classic/rlcard_envs/rlcard_base.py
index 5aed2116f..ab6651e02 100644
--- a/pettingzoo/classic/rlcard_envs/rlcard_base.py
+++ b/pettingzoo/classic/rlcard_envs/rlcard_base.py
@@ -129,9 +129,7 @@ def reset(self, seed=None, options=None):
         self.truncations = self._convert_to_dict(
             [False for _ in range(self.num_agents)]
         )
-        self.infos = self._convert_to_dict(
-            [{"legal_moves": []} for _ in range(self.num_agents)]
-        )
+        self.infos = self._convert_to_dict([{} for _ in range(self.num_agents)])
         self.next_legal_moves = list(sorted(obs["legal_actions"]))
         self._last_obs = obs["obs"]
diff --git a/pettingzoo/utils/wrappers/__init__.py b/pettingzoo/utils/wrappers/__init__.py
index 494babb3c..0e94f106c 100644
--- a/pettingzoo/utils/wrappers/__init__.py
+++ b/pettingzoo/utils/wrappers/__init__.py
@@ -3,5 +3,7 @@
 from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper
 from pettingzoo.utils.wrappers.capture_stdout import CaptureStdoutWrapper
 from pettingzoo.utils.wrappers.clip_out_of_bounds import ClipOutOfBoundsWrapper
+from pettingzoo.utils.wrappers.multi_episode_env import MultiEpisodeEnv
+from pettingzoo.utils.wrappers.multi_episode_parallel_env import MultiEpisodeParallelEnv
 from pettingzoo.utils.wrappers.order_enforcing import OrderEnforcingWrapper
 from pettingzoo.utils.wrappers.terminate_illegal import TerminateIllegalWrapper
diff --git a/pettingzoo/utils/wrappers/multi_episode_env.py b/pettingzoo/utils/wrappers/multi_episode_env.py
new file mode 100644
index 000000000..d924a0c45
--- /dev/null
+++ b/pettingzoo/utils/wrappers/multi_episode_env.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import copy
+
+from pettingzoo.utils.env import ActionType, AECEnv
+from pettingzoo.utils.wrappers.base import BaseWrapper
+
+
+class MultiEpisodeEnv(BaseWrapper):
+    """Wraps the base environment so that it runs for `num_episodes` episodes before truncating.
+
+    This is useful for creating evaluation environments.
+    When there are no more valid agents in the underlying environment, the environment is automatically reset.
+    After `num_episodes` have been run internally, the environment terminates normally.
+    The result of this wrapper is that the environment is no longer Markovian around the environment reset.
+    """
+
+    def __init__(self, env: AECEnv, num_episodes: int):
+        """__init__.
+
+        Args:
+            env (AECEnv): the base environment
+            num_episodes (int): the number of episodes to run the underlying environment
+        """
+        assert isinstance(
+            env, AECEnv
+        ), "MultiEpisodeEnv is only compatible with AEC environments"
+        super().__init__(env)
+
+        self._num_episodes = num_episodes
+
+    def reset(self, seed: int | None = None, options: dict | None = None) -> None:
+        """reset.
+
+        Args:
+            seed (int | None): seed for resetting the environment
+            options (dict | None): options passed through to the underlying reset
+
+        Returns:
+            None:
+        """
+        self._episodes_elapsed = 1
+        self._seed = copy.deepcopy(seed)
+        self._options = copy.deepcopy(options)
+        super().reset(seed=seed, options=options)
+
+    def step(self, action: ActionType) -> None:
+        """Steps the underlying environment, resetting it internally until `num_episodes` have elapsed.
+
+        When there are no more valid agents in the underlying environment, the environment is automatically reset.
+        After `num_episodes` have been run internally, the environment terminates normally.
+        The result of this wrapper is that the environment is no longer Markovian around the environment reset.
+
+        Args:
+            action (ActionType): action
+
+        Returns:
+            None:
+        """
+        super().step(action)
+        if self.agents:
+            return
+
+        # if we've crossed num_episodes, let the environment
+        # terminate normally (the agents list is already empty)
+        if self._episodes_elapsed >= self._num_episodes:
+            self.truncations = {agent: True for agent in self.agents}
+            return
+
+        # if there are no more agents and we haven't run enough episodes,
+        # increment the episode count and the seed, then reset
+        self._episodes_elapsed += 1
+        self._seed = self._seed + 1 if self._seed is not None else None
+        super().reset(seed=self._seed, options=self._options)
+        self.truncations = {agent: False for agent in self.agents}
+        self.terminations = {agent: False for agent in self.agents}
+
+    def __str__(self) -> str:
+        """__str__.
+
+        Returns:
+            str: the string representation of the wrapped environment
+        """
+        return str(self.env)
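A quick usage sketch of the AEC wrapper (not part of the diff): the environment choice and `num_episodes=3` are illustrative, and the loop mirrors the pattern exercised in `test/wrapper_test.py` below.

```python
from pettingzoo.classic import texas_holdem_no_limit_v6
from pettingzoo.utils.wrappers import MultiEpisodeEnv

# three hands of poker presented to the caller as one continuous AEC loop
env = MultiEpisodeEnv(texas_holdem_no_limit_v6.env(num_players=3), num_episodes=3)
env.reset(seed=42)

for agent in env.agent_iter():
    obs, reward, termination, truncation, info = env.last()
    if termination or truncation:
        action = None  # terminated/truncated agents must step with None
    else:
        # random legal action; a real policy would go here
        action = env.action_space(agent).sample(mask=obs["action_mask"])
    env.step(action)

env.close()
```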
diff --git a/pettingzoo/utils/wrappers/multi_episode_parallel_env.py b/pettingzoo/utils/wrappers/multi_episode_parallel_env.py
new file mode 100644
index 000000000..bbc9c9b60
--- /dev/null
+++ b/pettingzoo/utils/wrappers/multi_episode_parallel_env.py
@@ -0,0 +1,100 @@
+from __future__ import annotations
+
+import copy
+
+from pettingzoo.utils.env import ActionType, AgentID, ObsType, ParallelEnv
+from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper
+
+
+class MultiEpisodeParallelEnv(BaseParallelWrapper):
+    """Wraps the base environment so that it runs for `num_episodes` episodes before truncating.
+
+    This is useful for creating evaluation environments.
+    When there are no more valid agents in the underlying environment, the environment is automatically reset.
+    When this happens, the `observation` and `info` returned by `step()` are replaced with those of the reset environment.
+    The result of this wrapper is that the environment is no longer Markovian around the environment reset.
+    """
+
+    def __init__(self, env: ParallelEnv, num_episodes: int):
+        """__init__.
+
+        Args:
+            env (ParallelEnv): the base environment
+            num_episodes (int): the number of episodes to run the underlying environment
+        """
+        super().__init__(env)
+        assert isinstance(
+            env, ParallelEnv
+        ), "MultiEpisodeParallelEnv is only compatible with ParallelEnv environments."
+
+        self._num_episodes = num_episodes
+
+    def reset(
+        self, seed: int | None = None, options: dict | None = None
+    ) -> tuple[dict[AgentID, ObsType], dict[AgentID, dict]]:
+        """reset.
+
+        Args:
+            seed (int | None): seed for resetting the environment
+            options (dict | None): options passed through to the underlying reset
+
+        Returns:
+            tuple[dict[AgentID, ObsType], dict[AgentID, dict]]:
+        """
+        obs, info = super().reset(seed=seed, options=options)
+
+        self._seed = copy.deepcopy(seed)
+        self._options = copy.deepcopy(options)
+        self._episodes_elapsed = 1
+
+        return obs, info
+
+    def step(
+        self, actions: dict[AgentID, ActionType]
+    ) -> tuple[
+        dict[AgentID, ObsType],
+        dict[AgentID, float],
+        dict[AgentID, bool],
+        dict[AgentID, bool],
+        dict[AgentID, dict],
+    ]:
+        """Steps the environment.
+
+        When there are no more valid agents in the underlying environment, the environment is automatically reset.
+        When this happens, the `observation` and `info` returned by `step()` are replaced with those of the reset environment.
+        The result of this wrapper is that the environment is no longer Markovian around the environment reset.
+
+        Args:
+            actions (dict[AgentID, ActionType]): dictionary mapping of `AgentID`s to actions
+
+        Returns:
+            tuple[
+                dict[AgentID, ObsType],
+                dict[AgentID, float],
+                dict[AgentID, bool],
+                dict[AgentID, bool],
+                dict[AgentID, dict],
+            ]:
+        """
+        obs, rew, term, trunc, info = super().step(actions)
+        # mask out per-episode signals: internal episode boundaries are hidden
+        term = {agent: False for agent in term}
+        trunc = {agent: False for agent in trunc}
+
+        if self.agents:
+            return obs, rew, term, trunc, info
+
+        # only truncate once `num_episodes` have elapsed
+        if self._episodes_elapsed >= self._num_episodes:
+            term = {agent: False for agent in term}
+            trunc = {agent: True for agent in trunc}
+            return obs, rew, term, trunc, info
+
+        # if the underlying environment ends before `num_episodes` have elapsed,
+        # increment the episode count and the seed, reset, and return the new
+        # `observation` and `info` in place of the old ones; this is what makes
+        # the environment non-Markovian around the reset points
+        self._episodes_elapsed += 1
+        self._seed = self._seed + 1 if self._seed is not None else None
+        obs, info = super().reset(seed=self._seed, options=self._options)
+        return obs, rew, term, trunc, info
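And the equivalent rollout loop for the parallel wrapper, again an illustrative sketch rather than part of the change; compare with the pistonball test below.

```python
from pettingzoo.butterfly import pistonball_v6
from pettingzoo.utils.wrappers import MultiEpisodeParallelEnv

# two pistonball episodes exposed as a single uninterrupted rollout
env = MultiEpisodeParallelEnv(pistonball_v6.parallel_env(), num_episodes=2)
observations, infos = env.reset(seed=42)

while env.agents:
    # random actions; a real policy would map observations to actions
    actions = {agent: env.action_space(agent).sample() for agent in env.agents}
    observations, rewards, terminations, truncations, infos = env.step(actions)

env.close()
```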
diff --git a/test/wrapper_test.py b/test/wrapper_test.py
new file mode 100644
index 000000000..650fe328b
--- /dev/null
+++ b/test/wrapper_test.py
@@ -0,0 +1,69 @@
+from __future__ import annotations
+
+import pytest
+
+from pettingzoo.butterfly import pistonball_v6
+from pettingzoo.classic import texas_holdem_no_limit_v6
+from pettingzoo.utils.wrappers import MultiEpisodeEnv, MultiEpisodeParallelEnv
+
+
+@pytest.mark.parametrize("num_episodes", [1, 2, 3, 4, 5, 6])
+def test_multi_episode_env_wrapper(num_episodes: int) -> None:
+    """test_multi_episode_env_wrapper.
+
+    The number of steps per episode is dictated by the seeding of the action space, not the environment.
+
+    Args:
+        num_episodes: number of episodes to run the MultiEpisodeEnv
+    """
+    env = texas_holdem_no_limit_v6.env(num_players=3)
+    env = MultiEpisodeEnv(env, num_episodes=num_episodes)
+    env.reset(seed=42)
+
+    steps = 0
+    for agent in env.agent_iter():
+        steps += 1
+        obs, rew, term, trunc, info = env.last()
+
+        if term or trunc:
+            action = None
+        else:
+            action_space = env.action_space(agent)
+            action_space.seed(0)
+            action = action_space.sample(mask=obs["action_mask"])
+
+        env.step(action)
+
+    env.close()
+
+    assert (
+        steps == num_episodes * 6
+    ), f"Expected to have 6 steps per episode, got {steps / num_episodes}."
+
+
+@pytest.mark.parametrize("num_episodes", [1, 2, 3, 4, 5, 6])
+def test_multi_episode_parallel_env_wrapper(num_episodes: int) -> None:
+    """test_multi_episode_parallel_env_wrapper.
+
+    The default action for this test is to move all pistons down. This results in an episode length of 125.
+
+    Args:
+        num_episodes: number of episodes to run the MultiEpisodeParallelEnv
+    """
+    env = pistonball_v6.parallel_env()
+    env = MultiEpisodeParallelEnv(env, num_episodes=num_episodes)
+    _ = env.reset(seed=42)
+
+    steps = 0
+    while env.agents:
+        steps += 1
+        # this is where you would insert your policy
+        actions = {agent: env.action_space(agent).low for agent in env.agents}
+
+        _ = env.step(actions)
+
+    env.close()
+
+    assert (
+        steps == num_episodes * 125
+    ), f"Expected to have 125 steps per episode, got {steps / num_episodes}."
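One behavioral detail worth flagging for reviewers: the internal episode boundaries are invisible to the caller. The parallel wrapper masks `terminations`/`truncations` to all-`False` at every internal reset and only reports all-`True` truncations once `num_episodes` have elapsed. A small sketch of that contract, reusing the deterministic all-pistons-down policy from the test above:

```python
from pettingzoo.butterfly import pistonball_v6
from pettingzoo.utils.wrappers import MultiEpisodeParallelEnv

env = MultiEpisodeParallelEnv(pistonball_v6.parallel_env(), num_episodes=2)
env.reset(seed=1)

boundary_signals = 0
while env.agents:
    # deterministic policy: push every piston down, as in the test above
    actions = {agent: env.action_space(agent).low for agent in env.agents}
    _, _, terminations, truncations, _ = env.step(actions)
    if any(terminations.values()) or any(truncations.values()):
        boundary_signals += 1

# the reset between the two internal episodes emits no done signal;
# only the very last step reports truncation
assert boundary_signals == 1
```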