Add multi-episode wrapper for AEC and parallel envs (#1105)
Co-authored-by: elliottower <[email protected]>
1 parent 6c33934 · commit e907e05
Showing 5 changed files with 261 additions and 3 deletions.
@@ -0,0 +1,87 @@
from __future__ import annotations

import copy

from pettingzoo.utils.env import ActionType, AECEnv
from pettingzoo.utils.wrappers.base import BaseWrapper


class MultiEpisodeEnv(BaseWrapper):
    """Creates a new environment using the base environment that runs for `num_episodes` before truncating.

    This is useful for creating evaluation environments.
    When there are no more valid agents in the underlying environment, the environment is automatically reset.
    After `num_episodes` have been run internally, the environment terminates normally.
    The result of this wrapper is that the environment is no longer Markovian around the environment reset.
    """
    def __init__(self, env: AECEnv, num_episodes: int):
        """Initializes the wrapper.

        Args:
            env (AECEnv): the base AEC environment to wrap
            num_episodes (int): the number of episodes to run before truncating
        """
        assert isinstance(
            env, AECEnv
        ), "MultiEpisodeEnv is only compatible with AEC environments"
        super().__init__(env)

        self._num_episodes = num_episodes
    def reset(self, seed: int | None = None, options: dict | None = None) -> None:
        """Resets the environment and the internal episode counter.

        Args:
            seed (int | None): seed for resetting the environment
            options (dict | None): options passed through to the base environment
        """
        self._episodes_elapsed = 1
        self._seed = copy.deepcopy(seed)
        self._options = copy.deepcopy(options)
        super().reset(seed=seed, options=options)
    def step(self, action: ActionType) -> None:
        """Steps the underlying environment, resetting it between episodes until `num_episodes` have elapsed.

        When there are no more valid agents in the underlying environment, the environment is automatically reset.
        After `num_episodes` have been run internally, the environment terminates normally.
        The result of this wrapper is that the environment is no longer Markovian around the environment reset.

        Args:
            action (ActionType): action for the current agent
        """
        super().step(action)
        if self.agents:
            return

        # if we've crossed num_episodes, truncate all agents
        # and let the environment terminate normally
        if self._episodes_elapsed >= self._num_episodes:
            self.truncations = {agent: True for agent in self.agents}
            return

        # if there are no more agents and we haven't run enough episodes,
        # increment the episode count and the seed, then reset
        self._episodes_elapsed += 1
        # use `is not None` so a seed of 0 is still incremented rather than dropped
        self._seed = self._seed + 1 if self._seed is not None else None
        super().reset(seed=self._seed, options=self._options)
        self.truncations = {agent: False for agent in self.agents}
        self.terminations = {agent: False for agent in self.agents}
    def __str__(self) -> str:
        """Returns the string representation of the underlying environment.

        Returns:
            str: the wrapped environment's string representation
        """
        return str(self.env)
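
As a quick illustration (not part of the commit), the wrapper drops into a standard AEC loop unchanged. The choice of game (`texas_holdem_no_limit_v6`) and `num_episodes=3` below is arbitrary, mirroring the tests further down:

from pettingzoo.classic import texas_holdem_no_limit_v6
from pettingzoo.utils.wrappers import MultiEpisodeEnv

# run three back-to-back hands, presented as a single uninterrupted episode
env = MultiEpisodeEnv(texas_holdem_no_limit_v6.env(num_players=3), num_episodes=3)
env.reset(seed=42)

for agent in env.agent_iter():
    obs, reward, termination, truncation, info = env.last()
    if termination or truncation:
        action = None
    else:
        # sample a random legal action; a real policy would go here
        action = env.action_space(agent).sample(mask=obs["action_mask"])
    env.step(action)
env.close()
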
pettingzoo/utils/wrappers/multi_episode_parallel_env.py (102 additions, 0 deletions)
@@ -0,0 +1,102 @@
from __future__ import annotations

import copy

from pettingzoo.utils.env import ActionType, AgentID, ObsType, ParallelEnv
from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper


class MultiEpisodeParallelEnv(BaseParallelWrapper):
    """Creates a new environment using the base environment that runs for `num_episodes` before truncating.

    This is useful for creating evaluation environments.
    When there are no more valid agents in the underlying environment, the environment is automatically reset.
    When this happens, the `observation` and `info` returned by `step()` are replaced with that of the reset environment.
    The result of this wrapper is that the environment is no longer Markovian around the environment reset.
    """
    def __init__(self, env: ParallelEnv, num_episodes: int):
        """Initializes the wrapper.

        Args:
            env (ParallelEnv): the base parallel environment to wrap
            num_episodes (int): the number of episodes to run the underlying environment
        """
        super().__init__(env)
        assert isinstance(
            env, ParallelEnv
        ), "MultiEpisodeParallelEnv is only compatible with ParallelEnv environments."

        self._num_episodes = num_episodes
    def reset(
        self, seed: int | None = None, options: dict | None = None
    ) -> tuple[dict[AgentID, ObsType], dict[AgentID, dict]]:
        """Resets the environment and the internal episode counter.

        Args:
            seed (int | None): seed for resetting the environment
            options (dict | None): options passed through to the base environment

        Returns:
            tuple[dict[AgentID, ObsType], dict[AgentID, dict]]: the initial observations and infos
        """
        obs, info = super().reset(seed=seed, options=options)

        self._seed = copy.deepcopy(seed)
        self._options = copy.deepcopy(options)
        self._episodes_elapsed = 1

        return obs, info
    def step(
        self, actions: dict[AgentID, ActionType]
    ) -> tuple[
        dict[AgentID, ObsType],
        dict[AgentID, float],
        dict[AgentID, bool],
        dict[AgentID, bool],
        dict[AgentID, dict],
    ]:
        """Steps the environment, resetting it between episodes until `num_episodes` have elapsed.

        When there are no more valid agents in the underlying environment, the environment is automatically reset.
        When this happens, the `observation` and `info` returned by `step()` are replaced with that of the reset environment.
        The result of this wrapper is that the environment is no longer Markovian around the environment reset.

        Args:
            actions (dict[AgentID, ActionType]): dictionary mapping `AgentID`s to actions

        Returns:
            tuple[
                dict[AgentID, ObsType],
                dict[AgentID, float],
                dict[AgentID, bool],
                dict[AgentID, bool],
                dict[AgentID, dict],
            ]: the usual parallel-env step tuple
        """
        obs, rew, term, trunc, info = super().step(actions)
        # suppress all termination/truncation signals from the underlying environment
        term = {agent: False for agent in term}
        trunc = {agent: False for agent in trunc}

        if self.agents:
            return obs, rew, term, trunc, info

        # only truncate once `num_episodes` have elapsed,
        # letting the environment end normally
        if self._episodes_elapsed >= self._num_episodes:
            trunc = {agent: True for agent in trunc}
            return obs, rew, term, trunc, info

        # if the underlying episode ended before `num_episodes` have elapsed,
        # reset the environment and override the returned observation and info
        # with those of the freshly reset environment; the env is therefore
        # no longer Markovian at the reset points
        self._episodes_elapsed += 1
        # use `is not None` so a seed of 0 is still incremented rather than dropped
        self._seed = self._seed + 1 if self._seed is not None else None
        obs, info = super().reset(seed=self._seed, options=self._options)
        return obs, rew, term, trunc, info
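
A matching sketch for the parallel API (again illustrative rather than part of the commit); `pistonball_v6` and `num_episodes=2` are arbitrary choices mirroring the tests below:

from pettingzoo.butterfly import pistonball_v6
from pettingzoo.utils.wrappers import MultiEpisodeParallelEnv

# two full pistonball episodes, presented as one uninterrupted rollout
env = MultiEpisodeParallelEnv(pistonball_v6.parallel_env(), num_episodes=2)
observations, infos = env.reset(seed=42)

while env.agents:
    # random actions stand in for a policy
    actions = {agent: env.action_space(agent).sample() for agent in env.agents}
    observations, rewards, terminations, truncations, infos = env.step(actions)
env.close()
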
@@ -0,0 +1,69 @@
from __future__ import annotations

import pytest

from pettingzoo.butterfly import pistonball_v6
from pettingzoo.classic import texas_holdem_no_limit_v6
from pettingzoo.utils.wrappers import MultiEpisodeEnv, MultiEpisodeParallelEnv
@pytest.mark.parametrize("num_episodes", [1, 2, 3, 4, 5, 6])
def test_multi_episode_env_wrapper(num_episodes: int) -> None:
    """Tests the MultiEpisodeEnv wrapper on an AEC environment.

    The number of steps per episode is dictated by the seeding of the action space, not the environment.

    Args:
        num_episodes: number of episodes to run the MultiEpisodeEnv
    """
    env = texas_holdem_no_limit_v6.env(num_players=3)
    env = MultiEpisodeEnv(env, num_episodes=num_episodes)
    env.reset(seed=42)

    steps = 0
    for agent in env.agent_iter():
        steps += 1
        obs, rew, term, trunc, info = env.last()

        if term or trunc:
            action = None
        else:
            action_space = env.action_space(agent)
            action_space.seed(0)
            action = action_space.sample(mask=obs["action_mask"])

        env.step(action)

    env.close()

    assert (
        steps == num_episodes * 6
    ), f"Expected to have 6 steps per episode, got {steps / num_episodes}."

@pytest.mark.parametrize("num_episodes", [1, 2, 3, 4, 5, 6])
def test_multi_episode_parallel_env_wrapper(num_episodes: int) -> None:
    """Tests the MultiEpisodeParallelEnv wrapper on a parallel environment.

    The default action for this test is to move all pistons down, which results in an episode length of 125.

    Args:
        num_episodes: number of episodes to run the MultiEpisodeParallelEnv
    """
    env = pistonball_v6.parallel_env()
    env = MultiEpisodeParallelEnv(env, num_episodes=num_episodes)
    _ = env.reset(seed=42)

    steps = 0
    while env.agents:
        steps += 1
        # this is where you would insert your policy
        actions = {agent: env.action_space(agent).low for agent in env.agents}

        _ = env.step(actions)

    env.close()

    assert (
        steps == num_episodes * 125
    ), f"Expected to have 125 steps per episode, got {steps / num_episodes}."