Add multi-episode wrapper for AEC and parallel envs #1105
Merged: 17 commits, Sep 27, 2023
4 changes: 1 addition & 3 deletions pettingzoo/classic/rlcard_envs/rlcard_base.py
@@ -129,9 +129,7 @@ def reset(self, seed=None, options=None):
         self.truncations = self._convert_to_dict(
             [False for _ in range(self.num_agents)]
         )
-        self.infos = self._convert_to_dict(
-            [{"legal_moves": []} for _ in range(self.num_agents)]
-        )
+        self.infos = self._convert_to_dict([{} for _ in range(self.num_agents)])
         self.next_legal_moves = list(sorted(obs["legal_actions"]))
         self._last_obs = obs["obs"]

2 changes: 2 additions & 0 deletions pettingzoo/utils/wrappers/__init__.py
@@ -3,5 +3,7 @@
 from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper
 from pettingzoo.utils.wrappers.capture_stdout import CaptureStdoutWrapper
 from pettingzoo.utils.wrappers.clip_out_of_bounds import ClipOutOfBoundsWrapper
+from pettingzoo.utils.wrappers.multi_episode_env import MultiEpisodeEnv
+from pettingzoo.utils.wrappers.multi_episode_parallel_env import MultiEpisodeParallelEnv
 from pettingzoo.utils.wrappers.order_enforcing import OrderEnforcingWrapper
 from pettingzoo.utils.wrappers.terminate_illegal import TerminateIllegalWrapper
87 changes: 87 additions & 0 deletions pettingzoo/utils/wrappers/multi_episode_env.py
@@ -0,0 +1,87 @@
from __future__ import annotations

import copy

from pettingzoo.utils.env import ActionType, AECEnv
from pettingzoo.utils.wrappers.base import BaseWrapper


class MultiEpisodeEnv(BaseWrapper):
    """Creates a new environment using the base environment that runs for `num_episodes` before truncating.

    This is useful for creating evaluation environments.
    When there are no more valid agents in the underlying environment, the environment is automatically reset.
    After `num_episodes` have been run internally, the environment terminates normally.
    The result of this wrapper is that the environment is no longer Markovian around the environment reset.
    """

    def __init__(self, env: AECEnv, num_episodes: int):
        """__init__.

        Args:
            env (AECEnv): the base AEC environment
            num_episodes (int): the number of episodes to run the underlying environment for
        """
        assert isinstance(
            env, AECEnv
        ), "MultiEpisodeEnv is only compatible with AEC environments."
        super().__init__(env)

        self._num_episodes = num_episodes

    def reset(self, seed: int | None = None, options: dict | None = None) -> None:
        """reset.

        Args:
            seed (int | None): seed for resetting the environment
            options (dict | None): reset options

        Returns:
            None:
        """
        self._episodes_elapsed = 1
        self._seed = copy.deepcopy(seed)
        self._options = copy.deepcopy(options)
        super().reset(seed=seed, options=options)

    def step(self, action: ActionType) -> None:
        """Steps the underlying environment for `num_episodes`.

        This is useful for creating evaluation environments.
        When there are no more valid agents in the underlying environment, the environment is automatically reset.
        After `num_episodes` have been run internally, the environment terminates normally.
        The result of this wrapper is that the environment is no longer Markovian around the environment reset.

        Args:
            action (ActionType): action

        Returns:
            None:
        """
        super().step(action)
        if self.agents:
            return

        # if we've crossed num_episodes, truncate all agents
        # and let the environment terminate normally
        if self._episodes_elapsed >= self._num_episodes:
            self.truncations = {agent: True for agent in self.agents}
            return

        # if there are no more agents and we haven't run enough episodes,
        # increment the number of episodes and the seed for reset
        self._episodes_elapsed += 1
        self._seed = self._seed + 1 if self._seed else None
        super().reset(seed=self._seed, options=self._options)
        self.truncations = {agent: False for agent in self.agents}
        self.terminations = {agent: False for agent in self.agents}

    def __str__(self) -> str:
        """__str__.

        Returns:
            str:
        """
        return str(self.env)
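
A minimal usage sketch of the AEC wrapper above (this is an illustration, not part of the PR: it assumes the `texas_holdem_no_limit_v6` environment also used in the test file below, the episode count and seed are arbitrary, and random legal actions stand in for a trained policy):

```python
from pettingzoo.classic import texas_holdem_no_limit_v6
from pettingzoo.utils.wrappers import MultiEpisodeEnv

# run 5 internal episodes before the wrapper truncates the environment
env = MultiEpisodeEnv(texas_holdem_no_limit_v6.env(num_players=3), num_episodes=5)
env.reset(seed=42)

for agent in env.agent_iter():
    obs, reward, termination, truncation, info = env.last()
    if termination or truncation:
        action = None
    else:
        # sample a random legal action; a trained policy would go here
        action = env.action_space(agent).sample(mask=obs["action_mask"])
    env.step(action)

env.close()
```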
102 changes: 102 additions & 0 deletions pettingzoo/utils/wrappers/multi_episode_parallel_env.py
@@ -0,0 +1,102 @@
from __future__ import annotations

import copy

from pettingzoo.utils.env import ActionType, AgentID, ObsType, ParallelEnv
from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper


class MultiEpisodeParallelEnv(BaseParallelWrapper):
    """Creates a new environment using the base environment that runs for `num_episodes` before truncating.

    This is useful for creating evaluation environments.
    When there are no more valid agents in the underlying environment, the environment is automatically reset.
    When this happens, the `observation` and `info` returned by `step()` are replaced with those of the reset environment.
    The result of this wrapper is that the environment is no longer Markovian around the environment reset.
    """

    def __init__(self, env: ParallelEnv, num_episodes: int):
        """__init__.

        Args:
            env (ParallelEnv): the base parallel environment
            num_episodes (int): the number of episodes to run the underlying environment for
        """
        super().__init__(env)
        assert isinstance(
            env, ParallelEnv
        ), "MultiEpisodeParallelEnv is only compatible with ParallelEnv environments."

        self._num_episodes = num_episodes

    def reset(
        self, seed: int | None = None, options: dict | None = None
    ) -> tuple[dict[AgentID, ObsType], dict[AgentID, dict]]:
        """reset.

        Args:
            seed (int | None): seed for resetting the environment
            options (dict | None): reset options

        Returns:
            tuple[dict[AgentID, ObsType], dict[AgentID, dict]]:
        """
        obs, info = super().reset(seed=seed, options=options)

        self._seed = copy.deepcopy(seed)
        self._options = copy.deepcopy(options)
        self._episodes_elapsed = 1

        return obs, info

    def step(
        self, actions: dict[AgentID, ActionType]
    ) -> tuple[
        dict[AgentID, ObsType],
        dict[AgentID, float],
        dict[AgentID, bool],
        dict[AgentID, bool],
        dict[AgentID, dict],
    ]:
        """Steps the environment.

        When there are no more valid agents in the underlying environment, the environment is automatically reset.
        When this happens, the `observation` and `info` returned by `step()` are replaced with those of the reset environment.
        The result of this wrapper is that the environment is no longer Markovian around the environment reset.

        Args:
            actions (dict[AgentID, ActionType]): dictionary mapping `AgentID`s to actions

        Returns:
            tuple[
                dict[AgentID, ObsType],
                dict[AgentID, float],
                dict[AgentID, bool],
                dict[AgentID, bool],
                dict[AgentID, dict],
            ]:
        """
        obs, rew, term, trunc, info = super().step(actions)
        term = {agent: False for agent in term}
        trunc = {agent: False for agent in term}

        if self.agents:
            return obs, rew, term, trunc, info

        # override term/trunc so the environment only truncates once num_episodes have elapsed
        if self._episodes_elapsed >= self._num_episodes:
            term = {agent: False for agent in term}
            trunc = {agent: True for agent in term}
            return obs, rew, term, trunc, info

        # if any agent terminates or truncates
        # and we haven't elapsed `num_episodes`,
        # reset the environment.
        # we also override the observations and infos;
        # the result is that this env is no longer Markovian
        # at the reset points.
        # increment the number of episodes and the seed for reset
        self._episodes_elapsed += 1
        self._seed = self._seed + 1 if self._seed else None
        obs, info = super().reset(seed=self._seed, options=self._options)
        return obs, rew, term, trunc, info
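
Similarly, a minimal usage sketch of the parallel wrapper (again illustrative rather than part of the PR: it assumes the `pistonball_v6` environment also used in the test file below, the episode count and seed are arbitrary, random actions stand in for a trained policy, and the per-agent return bookkeeping is only there to show a typical evaluation loop):

```python
from pettingzoo.butterfly import pistonball_v6
from pettingzoo.utils.wrappers import MultiEpisodeParallelEnv

# run 3 internal episodes before the wrapper truncates the environment
env = MultiEpisodeParallelEnv(pistonball_v6.parallel_env(), num_episodes=3)
observations, infos = env.reset(seed=42)

returns: dict = {}
while env.agents:
    # sample random actions; a trained policy would go here
    actions = {agent: env.action_space(agent).sample() for agent in env.agents}
    observations, rewards, terminations, truncations, infos = env.step(actions)
    for agent, reward in rewards.items():
        # accumulate per-agent returns across all internal episodes
        returns[agent] = returns.get(agent, 0.0) + reward

env.close()
```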
69 changes: 69 additions & 0 deletions test/wrapper_test.py
@@ -0,0 +1,69 @@
from __future__ import annotations

import pytest

from pettingzoo.butterfly import pistonball_v6
from pettingzoo.classic import texas_holdem_no_limit_v6
from pettingzoo.utils.wrappers import MultiEpisodeEnv, MultiEpisodeParallelEnv


@pytest.mark.parametrize(("num_episodes"), [1, 2, 3, 4, 5, 6])
def test_multi_episode_env_wrapper(num_episodes: int) -> None:
    """test_multi_episode_env_wrapper.

    The number of steps per episode is dictated by the seeding of the action space, not the environment.

    Args:
        num_episodes: number of episodes to run the MultiEpisodeEnv
    """
    env = texas_holdem_no_limit_v6.env(num_players=3)
    env = MultiEpisodeEnv(env, num_episodes=num_episodes)
    env.reset(seed=42)

    steps = 0
    for agent in env.agent_iter():
        steps += 1
        obs, rew, term, trunc, info = env.last()

        if term or trunc:
            action = None
        else:
            action_space = env.action_space(agent)
            action_space.seed(0)
            action = action_space.sample(mask=obs["action_mask"])

        env.step(action)

    env.close()

    assert (
        steps == num_episodes * 6
    ), f"Expected to have 6 steps per episode, got {steps / num_episodes}."


@pytest.mark.parametrize(("num_episodes"), [1, 2, 3, 4, 5, 6])
def test_multi_episode_parallel_env_wrapper(num_episodes: int) -> None:
    """test_multi_episode_parallel_env_wrapper.

    The default action for this test is to move all pistons down. This results in an episode length of 125.

    Args:
        num_episodes: number of episodes to run the MultiEpisodeParallelEnv
    """
    env = pistonball_v6.parallel_env()
    env = MultiEpisodeParallelEnv(env, num_episodes=num_episodes)
    _ = env.reset(seed=42)

    steps = 0
    while env.agents:
        steps += 1
        # this is where you would insert your policy
        actions = {agent: env.action_space(agent).low for agent in env.agents}

        _ = env.step(actions)

    env.close()

    assert (
        steps == num_episodes * 125
    ), f"Expected to have 125 steps per episode, got {steps / num_episodes}."