Add multi-episode wrapper for AEC and parallel envs (#1105)
Co-authored-by: elliottower <[email protected]>
jjshoots and elliottower authored Sep 27, 2023
1 parent 6c33934 commit e907e05
Showing 5 changed files with 261 additions and 3 deletions.
4 changes: 1 addition & 3 deletions pettingzoo/classic/rlcard_envs/rlcard_base.py
@@ -129,9 +129,7 @@ def reset(self, seed=None, options=None):
         self.truncations = self._convert_to_dict(
             [False for _ in range(self.num_agents)]
         )
-        self.infos = self._convert_to_dict(
-            [{"legal_moves": []} for _ in range(self.num_agents)]
-        )
+        self.infos = self._convert_to_dict([{} for _ in range(self.num_agents)])
         self.next_legal_moves = list(sorted(obs["legal_actions"]))
         self._last_obs = obs["obs"]
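Note on this hunk: the per-agent `infos` of the rlcard-based classic environments no longer pre-populate a `legal_moves` key at reset; legal actions are read from the observation's action mask instead, as the new wrapper test in this commit does. A minimal consumer-side sketch (variable names are illustrative, not part of the diff):

observation, reward, termination, truncation, info = env.last()
# infos are now plain empty dicts; legal actions come from the observation
legal_action_mask = observation["action_mask"]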

2 changes: 2 additions & 0 deletions pettingzoo/utils/wrappers/__init__.py
@@ -3,5 +3,7 @@
 from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper
 from pettingzoo.utils.wrappers.capture_stdout import CaptureStdoutWrapper
 from pettingzoo.utils.wrappers.clip_out_of_bounds import ClipOutOfBoundsWrapper
+from pettingzoo.utils.wrappers.multi_episode_env import MultiEpisodeEnv
+from pettingzoo.utils.wrappers.multi_episode_parallel_env import MultiEpisodeParallelEnv
 from pettingzoo.utils.wrappers.order_enforcing import OrderEnforcingWrapper
 from pettingzoo.utils.wrappers.terminate_illegal import TerminateIllegalWrapper
87 changes: 87 additions & 0 deletions pettingzoo/utils/wrappers/multi_episode_env.py
@@ -0,0 +1,87 @@
from __future__ import annotations

import copy

from pettingzoo.utils.env import ActionType, AECEnv
from pettingzoo.utils.wrappers.base import BaseWrapper


class MultiEpisodeEnv(BaseWrapper):
    """Creates a new environment using the base environment that runs for `num_episodes` before truncating.

    This is useful for creating evaluation environments.
    When there are no more valid agents in the underlying environment, the environment is automatically reset.
    After `num_episodes` have been run internally, the environment terminates normally.

    The result of this wrapper is that the environment is no longer Markovian around the environment reset.
    """

    def __init__(self, env: AECEnv, num_episodes: int):
        """__init__.

        Args:
            env (AECEnv): env
            num_episodes (int): num_episodes
        """
        assert isinstance(
            env, AECEnv
        ), "MultiEpisodeEnv is only compatible with AEC environments"
        super().__init__(env)

        self._num_episodes = num_episodes

    def reset(self, seed: int | None = None, options: dict | None = None) -> None:
        """reset.

        Args:
            seed (int | None): seed
            options (dict | None): options

        Returns:
            None:
        """
        self._episodes_elapsed = 1
        self._seed = copy.deepcopy(seed)
        self._options = copy.deepcopy(options)
        super().reset(seed=seed, options=options)

    def step(self, action: ActionType) -> None:
        """Steps the underlying environment for `num_episodes`.

        This is useful for creating evaluation environments.
        When there are no more valid agents in the underlying environment, the environment is automatically reset.
        After `num_episodes` have been run internally, the environment terminates normally.

        The result of this wrapper is that the environment is no longer Markovian around the environment reset.

        Args:
            action (ActionType): action

        Returns:
            None:
        """
        super().step(action)
        if self.agents:
            return

        # if we've crossed num_episodes, truncate all agents
        # and let the environment terminate normally
        if self._episodes_elapsed >= self._num_episodes:
            self.truncations = {agent: True for agent in self.agents}
            return

        # if no more agents and haven't had enough episodes,
        # increment the number of episodes and the seed for reset
        self._episodes_elapsed += 1
        self._seed = self._seed + 1 if self._seed else None
        super().reset(seed=self._seed, options=self._options)
        self.truncations = {agent: False for agent in self.agents}
        self.terminations = {agent: False for agent in self.agents}

    def __str__(self) -> str:
        """__str__.

        Args:

        Returns:
            str:
        """
        return str(self.env)
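For reference, a minimal usage sketch of the new AEC wrapper, modeled on the test added later in this commit (the texas_holdem_no_limit_v6 environment and the random-action policy are illustrative choices, not part of this file):

from pettingzoo.classic import texas_holdem_no_limit_v6
from pettingzoo.utils.wrappers import MultiEpisodeEnv

# Run three internal episodes before the wrapper truncates all agents.
env = MultiEpisodeEnv(texas_holdem_no_limit_v6.env(num_players=3), num_episodes=3)
env.reset(seed=42)

for agent in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()
    if termination or truncation:
        action = None
    else:
        # sample a random legal action using the action mask from the observation
        action = env.action_space(agent).sample(mask=observation["action_mask"])
    env.step(action)

env.close()

Note that on each internal reset the wrapper bumps the stored seed by one (when a seed was supplied), so the stitched-together episodes are reproducible without being identical replays.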
102 changes: 102 additions & 0 deletions pettingzoo/utils/wrappers/multi_episode_parallel_env.py
@@ -0,0 +1,102 @@
from __future__ import annotations

import copy

from pettingzoo.utils.env import ActionType, AgentID, ObsType, ParallelEnv
from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper


class MultiEpisodeParallelEnv(BaseParallelWrapper):
    """Creates a new environment using the base environment that runs for `num_episodes` before truncating.

    This is useful for creating evaluation environments.
    When there are no more valid agents in the underlying environment, the environment is automatically reset.
    When this happens, the `observation` and `info` returned by `step()` are replaced with that of the reset environment.

    The result of this wrapper is that the environment is no longer Markovian around the environment reset.
    """

    def __init__(self, env: ParallelEnv, num_episodes: int):
        """__init__.

        Args:
            env (ParallelEnv): the base environment
            num_episodes (int): the number of episodes to run the underlying environment
        """
        super().__init__(env)
        assert isinstance(
            env, ParallelEnv
        ), "MultiEpisodeParallelEnv is only compatible with ParallelEnv environments."

        self._num_episodes = num_episodes

    def reset(
        self, seed: int | None = None, options: dict | None = None
    ) -> tuple[dict[AgentID, ObsType], dict[AgentID, dict]]:
        """reset.

        Args:
            seed (int | None): seed for resetting the environment
            options (dict | None): options

        Returns:
            tuple[dict[AgentID, ObsType], dict[AgentID, dict]]:
        """
        obs, info = super().reset(seed=seed, options=options)

        self._seed = copy.deepcopy(seed)
        self._options = copy.deepcopy(options)
        self._episodes_elapsed = 1

        return obs, info

    def step(
        self, actions: dict[AgentID, ActionType]
    ) -> tuple[
        dict[AgentID, ObsType],
        dict[AgentID, float],
        dict[AgentID, bool],
        dict[AgentID, bool],
        dict[AgentID, dict],
    ]:
        """Steps the environment.

        When there are no more valid agents in the underlying environment, the environment is automatically reset.
        When this happens, the `observation` and `info` returned by `step()` are replaced with that of the reset environment.
        The result of this wrapper is that the environment is no longer Markovian around the environment reset.

        Args:
            actions (dict[AgentID, ActionType]): dictionary mapping of `AgentID`s to actions

        Returns:
            tuple[
                dict[AgentID, ObsType],
                dict[AgentID, float],
                dict[AgentID, bool],
                dict[AgentID, bool],
                dict[AgentID, dict],
            ]:
        """
        obs, rew, term, trunc, info = super().step(actions)
        term = {agent: False for agent in term}
        trunc = {agent: False for agent in term}

        if self.agents:
            return obs, rew, term, trunc, info

        # override term/trunc so truncation only happens once `num_episodes` have elapsed
        if self._episodes_elapsed >= self._num_episodes:
            term = {agent: False for agent in term}
            trunc = {agent: True for agent in term}
            return obs, rew, term, trunc, info

        # if any agent terminates or truncates and we haven't elapsed `num_episodes`,
        # reset the environment; we also override the observations and infos,
        # so this env is no longer Markovian at the reset points.
        # increment the number of episodes and the seed for reset
        self._episodes_elapsed += 1
        self._seed = self._seed + 1 if self._seed else None
        obs, info = super().reset(seed=self._seed, options=self._options)
        return obs, rew, term, trunc, info
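A corresponding sketch for the parallel wrapper, modeled on the pistonball_v6 test added below (the random policy is illustrative only):

from pettingzoo.butterfly import pistonball_v6
from pettingzoo.utils.wrappers import MultiEpisodeParallelEnv

# Run two internal episodes before the wrapper truncates all agents.
env = MultiEpisodeParallelEnv(pistonball_v6.parallel_env(), num_episodes=2)
observations, infos = env.reset(seed=42)

while env.agents:
    # any mapping of agent -> valid action works here; random actions for illustration
    actions = {agent: env.action_space(agent).sample() for agent in env.agents}
    observations, rewards, terminations, truncations, infos = env.step(actions)

env.close()

When an underlying episode ends before `num_episodes` is reached, the observations and infos returned by `step()` are those of the freshly reset environment, which is exactly the non-Markovian behaviour the class docstring warns about.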
69 changes: 69 additions & 0 deletions test/wrapper_test.py
@@ -0,0 +1,69 @@
from __future__ import annotations

import pytest

from pettingzoo.butterfly import pistonball_v6
from pettingzoo.classic import texas_holdem_no_limit_v6
from pettingzoo.utils.wrappers import MultiEpisodeEnv, MultiEpisodeParallelEnv


@pytest.mark.parametrize(("num_episodes"), [1, 2, 3, 4, 5, 6])
def test_multi_episode_env_wrapper(num_episodes: int) -> None:
    """test_multi_episode_env_wrapper.

    The number of steps per environment is dictated by the seeding of the action space, not the environment.

    Args:
        num_episodes: number of episodes to run the MultiEpisodeEnv
    """
    env = texas_holdem_no_limit_v6.env(num_players=3)
    env = MultiEpisodeEnv(env, num_episodes=num_episodes)
    env.reset(seed=42)

    steps = 0
    for agent in env.agent_iter():
        steps += 1
        obs, rew, term, trunc, info = env.last()

        if term or trunc:
            action = None
        else:
            action_space = env.action_space(agent)
            action_space.seed(0)
            action = action_space.sample(mask=obs["action_mask"])

        env.step(action)

    env.close()

    assert (
        steps == num_episodes * 6
    ), f"Expected to have 6 steps per episode, got {steps / num_episodes}."


@pytest.mark.parametrize(("num_episodes"), [1, 2, 3, 4, 5, 6])
def test_multi_episode_parallel_env_wrapper(num_episodes) -> None:
    """test_multi_episode_parallel_env_wrapper.

    The default action for this test is to move all pistons down. This results in an episode length of 125.

    Args:
        num_episodes: number of episodes to run the MultiEpisodeParallelEnv
    """
    env = pistonball_v6.parallel_env()
    env = MultiEpisodeParallelEnv(env, num_episodes=num_episodes)
    _ = env.reset(seed=42)

    steps = 0
    while env.agents:
        steps += 1
        # this is where you would insert your policy
        actions = {agent: env.action_space(agent).low for agent in env.agents}

        _ = env.step(actions)

    env.close()

    assert (
        steps == num_episodes * 125
    ), f"Expected to have 125 steps per episode, got {steps / num_episodes}."
