Support PettingZoo Parallel API and action mask (#305)
* support PettingZoo Parallel API and action mask
* Fix pre-commit
* Add pettingzoo to github test deps
* add missing reset_global_context()
* fix ContinuousActionDistribution to receive action_mask but ignore the value for now
* exclude action_mask in default_make_actor_critic_func
* add docs for action mask and pettingzoo env
* improve doc
* apply action_mask to epsilons
* use original probs when all actions are masked
* apply softmax with mask
* remove unnecessary computations
* fix to apply mask to probs
* improve masked_softmax and masked_log_softmax for extreme cases

---------

Co-authored-by: Aleksei Petrenko <[email protected]>
1 parent 91d4322, commit abbc459
Showing 21 changed files with 493 additions and 43 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Action Masking | ||
|
||
Action masking restricts the set of actions available to an agent in a given state. It is particularly useful in environments where some actions are invalid or undesirable in specific situations. See the paper [A Closer Look at Invalid Action Masking in Policy Gradient Algorithms](https://arxiv.org/abs/2006.14171) for more details. | ||
|
||
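To make the idea concrete, here is a minimal, dependency-free sketch of how a mask can be applied when turning logits into action probabilities. It is an illustration only, not Sample Factory's actual implementation: the function name `masked_softmax_sketch` is hypothetical, and the fallback to the unmasked distribution when every action is masked is an assumption modeled on the commit note "use original probs when all actions are masked".

```python
import math


def masked_softmax_sketch(logits, mask):
    """Softmax over logits where entries with mask == 0 get probability 0."""
    # Assumption: if every action is masked, fall back to the unmasked softmax.
    if not any(mask):
        mask = [1] * len(logits)

    # Subtract the largest valid logit for numerical stability.
    max_logit = max(l for l, m in zip(logits, mask) if m)
    exps = [math.exp(l - max_logit) if m else 0.0 for l, m in zip(logits, mask)]
    total = sum(exps)
    return [e / total for e in exps]


probs = masked_softmax_sketch([1.0, 2.0, 3.0], [1, 0, 1])
```

Masked actions receive exactly zero probability, so they are never sampled, while the remaining probabilities still sum to one.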
## Implementing Action Masking | ||
|
||
To implement action masking in your environment, you need to add an `action_mask` field to the observation dictionary returned by your environment. Here's how to do it: | ||
|
||
1. Define the action mask space in your environment's observation space | ||
2. Generate and return the action mask in both `reset()` and `step()` methods | ||
|
||
Here's an example of a custom environment implementing action masking: | ||
|
||
```python | ||
import gymnasium as gym | ||
import numpy as np | ||
|
||
class CustomEnv(gym.Env): | ||
    def __init__(self, full_env_name, cfg, render_mode=None): | ||
        ... | ||
        self.observation_space = gym.spaces.Dict({ | ||
            "obs": gym.spaces.Box(low=0, high=1, shape=(3, 3, 2), dtype=np.int8), | ||
            "action_mask": gym.spaces.Box(low=0, high=1, shape=(9,), dtype=np.int8), | ||
        }) | ||
        self.action_space = gym.spaces.Discrete(9) | ||
|
||
    def reset(self, **kwargs): | ||
        ... | ||
        # Initial action mask that allows all actions (dtype matches the declared space) | ||
        action_mask = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=np.int8) | ||
        return {"obs": obs, "action_mask": action_mask}, info | ||
|
||
    def step(self, action): | ||
        ... | ||
        # Generate new action mask based on the current state (1 = allowed, 0 = masked) | ||
        action_mask = np.array([1, 0, 0, 1, 1, 1, 0, 1, 1], dtype=np.int8) | ||
        return {"obs": obs, "action_mask": action_mask}, reward, terminated, truncated, info | ||
``` |
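In practice the mask is derived from the environment state rather than hard-coded. As a sketch, assume a hypothetical tic-tac-toe board stored as a 3x3 nested list in which 0 marks an empty cell. This encoding and the helper name `compute_action_mask` are assumptions for illustration, not the `obs` layout used above. A move is then valid only on empty cells:

```python
def compute_action_mask(board):
    """Return a flat list of 9 ints: 1 where the cell is empty, 0 otherwise."""
    return [1 if cell == 0 else 0 for row in board for cell in row]


# Hypothetical board: 0 = empty, 1 and 2 = marks of the two players.
board = [
    [1, 0, 0],
    [0, 2, 0],
    [0, 0, 1],
]
mask = compute_action_mask(board)
print(mask)  # [0, 1, 1, 1, 0, 1, 1, 1, 0]
```

Before returning it from `reset()` or `step()`, such a list would be converted with `np.array(mask, dtype=np.int8)` so that it matches the declared `action_mask` space.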
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# PettingZoo | ||
|
||
[PettingZoo](https://pettingzoo.farama.org/) is a Python library for conducting research in multi-agent reinforcement learning. This guide explains how to use PettingZoo environments with Sample Factory. | ||
|
||
## Installation | ||
|
||
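Install Sample Factory together with its PettingZoo dependencies. The command below performs an editable install from a local checkout of the `sample-factory` repository: | ||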
|
||
```bash | ||
pip install -e sample-factory[pettingzoo] | ||
``` | ||
|
||
## Running Experiments | ||
|
||
Run PettingZoo experiments with the scripts in `sf_examples`. | ||
The default parameters are not tuned for throughput. | ||
|
||
To train a model in the `tictactoe_v3` environment: | ||
|
||
```bash | ||
python -m sf_examples.train_pettingzoo_env --algo=APPO --env=tictactoe_v3 --experiment="Experiment Name" | ||
``` | ||
|
||
To visualize the training results, use the `enjoy_pettingzoo_env` script: | ||
|
||
```bash | ||
python -m sf_examples.enjoy_pettingzoo_env --env=tictactoe_v3 --experiment="Experiment Name" | ||
``` | ||
|
||
Currently, the scripts in `sf_examples` are set up for the `tictactoe_v3` environment. To use other PettingZoo environments, you'll need to modify the scripts or add your own as explained below. | ||
|
||
### Adding a new PettingZoo environment | ||
|
||
To add a new PettingZoo environment, follow the instructions from [Custom environments](../03-customization/custom-environments.md), with the additional step of wrapping your PettingZoo environment with `sample_factory.envs.pettingzoo_envs.PettingZooParallelEnv`. | ||
|
||
Here's an example of how to create a factory function for a PettingZoo environment: | ||
|
||
```python | ||
from sample_factory.envs.pettingzoo_envs import PettingZooParallelEnv | ||
import some_pettingzoo_env # Import your desired PettingZoo environment | ||
|
||
def make_pettingzoo_env(full_env_name, cfg=None, env_config=None, render_mode=None): | ||
return PettingZooParallelEnv(some_pettingzoo_env.parallel_env(render_mode=render_mode)) | ||
``` | ||
|
||
Note: Sample Factory supports only the [Parallel API](https://pettingzoo.farama.org/api/parallel/) of PettingZoo. If your environment uses the AEC API, you can convert it to the Parallel API using `pettingzoo.utils.conversions.aec_to_parallel` or `pettingzoo.utils.conversions.turn_based_aec_to_parallel`. Be aware that these conversions have some limitations. With `turn_based_aec_to_parallel`, each `info` dict carries an `active_agent` entry, which the wrapper translates into an `is_active` flag. For more details, refer to the [PettingZoo documentation](https://pettingzoo.farama.org/api/wrappers/pz_wrappers/). |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
""" | ||
Gym env wrappers for PettingZoo -> Gymnasium transition. | ||
""" | ||
|
||
import gymnasium as gym | ||
|
||
|
||
class PettingZooParallelEnv(gym.Env): | ||
    def __init__(self, env): | ||
        if not all_equal([env.observation_space(a) for a in env.possible_agents]): | ||
            raise ValueError("All observation spaces must be equal") | ||
|
||
        if not all_equal([env.action_space(a) for a in env.possible_agents]): | ||
            raise ValueError("All action spaces must be equal") | ||
|
||
        self.env = env | ||
        self.metadata = env.metadata | ||
        self.render_mode = env.render_mode if hasattr(env, "render_mode") else env.unwrapped.render_mode | ||
        self.observation_space = normalize_observation_space(env.observation_space(env.possible_agents[0])) | ||
        self.action_space = env.action_space(env.possible_agents[0]) | ||
        self.num_agents = env.max_num_agents | ||
        self.is_multiagent = True | ||
|
||
    def reset(self, **kwargs): | ||
        obs, infos = self.env.reset(**kwargs) | ||
        obs = [normalize_observation(obs.get(a)) for a in self.env.possible_agents] | ||
        infos = [infos[a] if a in infos else dict(is_active=False) for a in self.env.possible_agents] | ||
        return obs, infos | ||
|
||
    def step(self, actions): | ||
        actions = dict(zip(self.env.possible_agents, actions)) | ||
        obs, rewards, terminations, truncations, infos = self.env.step(actions) | ||
|
||
        if not self.env.agents: | ||
            obs, infos = self.env.reset() | ||
|
||
        obs = [normalize_observation(obs.get(a)) for a in self.env.possible_agents] | ||
        rewards = [rewards.get(a) for a in self.env.possible_agents] | ||
        terminations = [terminations.get(a) for a in self.env.possible_agents] | ||
        truncations = [truncations.get(a) for a in self.env.possible_agents] | ||
        infos = [normalize_info(infos[a], a) if a in infos else dict(is_active=False) for a in self.env.possible_agents] | ||
        return obs, rewards, terminations, truncations, infos | ||
|
||
    def render(self): | ||
        return self.env.render() | ||
|
||
    def close(self): | ||
        self.env.close() | ||
|
||
|
||
def all_equal(l_) -> bool: | ||
    return all(v == l_[0] for v in l_) | ||
|
||
|
||
def normalize_observation_space(obs_space): | ||
    """Normalize observation space with the key "obs" that's specially handled as the main value.""" | ||
    if isinstance(obs_space, gym.spaces.Dict) and "observation" in obs_space.spaces: | ||
        spaces = dict(obs_space.spaces) | ||
        spaces["obs"] = spaces["observation"] | ||
        del spaces["observation"] | ||
        obs_space = gym.spaces.Dict(spaces) | ||
|
||
    return obs_space | ||
|
||
|
||
def normalize_observation(obs): | ||
    if isinstance(obs, dict) and "observation" in obs: | ||
        obs["obs"] = obs["observation"] | ||
        del obs["observation"] | ||
|
||
    return obs | ||
|
||
|
||
def normalize_info(info, agent): | ||
    """active_agent is available when using `turn_based_aec_to_parallel` of PettingZoo.""" | ||
    if isinstance(info, dict) and "active_agent" in info: | ||
        info["is_active"] = info["active_agent"] == agent | ||
|
||
    return info |
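The wrapper's core job is to turn PettingZoo's per-agent dictionaries into lists ordered by `possible_agents`. The dependency-free sketch below isolates that pattern as used in `reset()` and `step()`. The helper `dicts_to_lists` and the sample agent names are hypothetical and are not part of this commit.

```python
def dicts_to_lists(possible_agents, rewards, infos):
    """Order per-agent dicts into lists that follow possible_agents."""
    # Agents missing from the rewards dict yield None, mirroring dict.get().
    reward_list = [rewards.get(a) for a in possible_agents]
    # Agents missing from the infos dict are marked as inactive.
    info_list = [infos[a] if a in infos else {"is_active": False} for a in possible_agents]
    return reward_list, info_list


agents = ["player_1", "player_2"]
reward_list, info_list = dicts_to_lists(agents, {"player_1": 1.0}, {"player_1": {}})
print(reward_list)  # [1.0, None]
print(info_list)    # [{}, {'is_active': False}]
```

Keeping the list order tied to `possible_agents` means position `i` always refers to the same agent, even when that agent is absent from a given step's dictionaries.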