Commit
Support PettingZoo Parallel API and action mask (#305)
* support PettingZoo Parallel API and action mask

* Fix pre-commit

* Add pettingzoo to github test deps

* add missing reset_global_context()

* fix ContinuousActionDistribution to receive action_mask but ignore the value for now

* exclude action_mask in default_make_actor_critic_func

* add docs for action mask and pettingzoo env

* improve doc

* apply action_mask to epsilons

* use original probs when all actions are masked

* apply softmax with mask

* remove unnecessary computations

* fix to apply mask to probs

* improve masked_softmax and masked_log_softmax for extreme cases

---------

Co-authored-by: Aleksei Petrenko <[email protected]>
nkzawa and alex-petrenko authored Oct 23, 2024
1 parent 91d4322 commit abbc459
Showing 21 changed files with 493 additions and 43 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test-ci.yml
@@ -31,7 +31,7 @@ jobs:
- name: Install conda env & dependencies
run: |
conda install python=${{ matrix.python-version }}
pip install -e '.[atari, mujoco, envpool]'
pip install -e '.[atari, mujoco, envpool, pettingzoo]'
conda list
- name: Install test dependencies
run: |
@@ -75,7 +75,7 @@ jobs:
- name: Install conda env & dependencies
run: |
conda install python=${{ matrix.python-version }}
pip install -e '.[atari, mujoco]'
pip install -e '.[atari, mujoco, pettingzoo]'
conda list
- name: Install test dependencies
run: |
2 changes: 1 addition & 1 deletion Makefile
@@ -46,7 +46,7 @@ check-codestyle:
.PHONY: test

test:
pytest -s --maxfail=2
pytest -s --maxfail=2 -rA
# ; echo "Tests finished. You might need to type 'reset' and press Enter to fix the terminal window"


38 changes: 38 additions & 0 deletions docs/07-advanced-topics/action-masking.md
@@ -0,0 +1,38 @@
# Action Masking

Action masking restricts the set of actions available to an agent in certain states. This is particularly useful in environments where some actions are invalid or undesirable in specific situations. See [this paper](https://arxiv.org/abs/2006.14171) for more details.

## Implementing Action Masking

To implement action masking in your environment, you need to add an `action_mask` field to the observation dictionary returned by your environment. Here's how to do it:

1. Define the action mask space in your environment's observation space.
2. Generate and return the action mask from both the `reset()` and `step()` methods.

Here's an example of a custom environment implementing action masking:

```python
import gymnasium as gym
import numpy as np

class CustomEnv(gym.Env):
def __init__(self, full_env_name, cfg, render_mode=None):
...
self.observation_space = gym.spaces.Dict({
"obs": gym.spaces.Box(low=0, high=1, shape=(3, 3, 2), dtype=np.int8),
"action_mask": gym.spaces.Box(low=0, high=1, shape=(9,), dtype=np.int8),
})
self.action_space = gym.spaces.Discrete(9)

def reset(self, **kwargs):
...
# Initial action mask that allows all actions
action_mask = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1])
return {"obs": obs, "action_mask": action_mask}, info

def step(self, action):
...
# Generate new action mask based on the current state
action_mask = np.array([1, 0, 0, 1, 1, 1, 0, 1, 1])
return {"obs": obs, "action_mask": action_mask}, reward, terminated, truncated, info
```
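
Under the hood, Sample Factory uses the mask when building the categorical action distribution: logits of masked-out actions are pushed to a large negative value before the softmax, so those actions are sampled with (near-)zero probability. Below is a minimal, self-contained sketch that mirrors the `masked_softmax` helper added in this change; the tensors and values are purely illustrative.

```python
import torch
import torch.nn.functional as F


def masked_softmax(logits, mask):
    # Push masked-out logits toward -inf so they receive ~zero probability
    logits = logits + (mask == 0) * -1e9
    probs = F.softmax(logits, dim=-1)
    probs = probs * mask  # zero out any residual probability mass
    return probs / (probs.sum(dim=-1, keepdim=True) + 1e-13)


logits = torch.tensor([[1.0, 2.0, 0.5]])
mask = torch.tensor([[1.0, 0.0, 1.0]])  # action 1 is invalid in this state
print(masked_softmax(logits, mask))  # action 1 gets ~0 probability
```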
46 changes: 46 additions & 0 deletions docs/09-environment-integrations/pettingzoo.md
@@ -0,0 +1,46 @@
# PettingZoo

[PettingZoo](https://pettingzoo.farama.org/) is a Python library for conducting research in multi-agent reinforcement learning. This guide explains how to use PettingZoo environments with Sample Factory.

## Installation

Install Sample Factory with the PettingZoo dependencies using pip:

```bash
pip install -e sample-factory[pettingzoo]
```

## Running Experiments

Run PettingZoo experiments with the scripts in `sf_examples`.
The default parameters are not tuned for throughput.

To train a model in the `tictactoe_v3` environment:

```bash
python -m sf_examples.train_pettingzoo_env --algo=APPO --env=tictactoe_v3 --experiment="Experiment Name"
```

To visualize the training results, use the `enjoy_pettingzoo_env` script:

```bash
python -m sf_examples.enjoy_pettingzoo_env --env=tictactoe_v3 --experiment="Experiment Name"
```

Currently, the scripts in `sf_examples` are set up for the `tictactoe_v3` environment. To use other PettingZoo environments, you'll need to modify the scripts or add your own as explained below.

### Adding a new PettingZoo environment

To add a new PettingZoo environment, follow the instructions from [Custom environments](../03-customization/custom-environments.md), with the additional step of wrapping your PettingZoo environment with `sample_factory.envs.pettingzoo_envs.PettingZooParallelEnv`.

Here's an example of how to create a factory function for a PettingZoo environment:

```python
from sample_factory.envs.pettingzoo_envs import PettingZooParallelEnv
import some_pettingzoo_env # Import your desired PettingZoo environment

def make_pettingzoo_env(full_env_name, cfg=None, env_config=None, render_mode=None):
return PettingZooParallelEnv(some_pettingzoo_env.parallel_env(render_mode=render_mode))
```
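
As with any custom environment, the factory function then needs to be registered under an environment name before training. A minimal sketch, assuming the usual `register_env` helper described in the custom-environments guide (the environment name below is a placeholder):

```python
from sample_factory.envs.env_utils import register_env


def register_custom_components():
    # "my_pettingzoo_env" is illustrative; use any unique name and pass it via --env
    register_env("my_pettingzoo_env", make_pettingzoo_env)
```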

Note: Sample Factory supports only the [Parallel API](https://pettingzoo.farama.org/api/parallel/) of PettingZoo. If your environment uses the AEC API, you can convert it to Parallel API using `pettingzoo.utils.conversions.aec_to_parallel` or `pettingzoo.utils.conversions.turn_based_aec_to_parallel`. Be aware that these conversions have some limitations. For more details, refer to the [PettingZoo documentation](https://pettingzoo.farama.org/api/wrappers/pz_wrappers/).
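
For instance, a turn-based classic environment such as `tictactoe_v3` ships as an AEC environment and could be wrapped roughly like this (a sketch; the bundled `sf_examples` scripts may differ in detail):

```python
from pettingzoo.classic import tictactoe_v3
from pettingzoo.utils.conversions import turn_based_aec_to_parallel

from sample_factory.envs.pettingzoo_envs import PettingZooParallelEnv


def make_tictactoe_env(full_env_name, cfg=None, env_config=None, render_mode=None):
    # Convert the turn-based AEC env to the Parallel API before wrapping it
    aec_env = tictactoe_v3.env(render_mode=render_mode)
    return PettingZooParallelEnv(turn_based_aec_to_parallel(aec_env))
```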
2 changes: 2 additions & 0 deletions mkdocs.yml
@@ -156,6 +156,7 @@ nav:
- 07-advanced-topics/passing-info.md
- 07-advanced-topics/observer.md
- 07-advanced-topics/profiling.md
- 07-advanced-topics/action-masking.md
- Miscellaneous:
- 08-miscellaneous/tests.md
- 08-miscellaneous/v1-to-v2.md
@@ -170,6 +171,7 @@ nav:
- 09-environment-integrations/nethack.md
- 09-environment-integrations/brax.md
- 09-environment-integrations/swarm-rl.md
- 09-environment-integrations/pettingzoo.md
- Huggingface Integration:
- 10-huggingface/huggingface.md
- Release Notes:
5 changes: 4 additions & 1 deletion sample_factory/algo/sampling/inference_worker.py
@@ -321,11 +321,14 @@ def _handle_policy_steps(self, timing):
if actor_critic.training:
actor_critic.eval() # need to call this because we can be in serial mode

action_mask = (
ensure_torch_tensor(obs.pop("action_mask")).to(self.device) if "action_mask" in obs else None
)
normalized_obs = prepare_and_normalize_obs(actor_critic, obs)
rnn_states = ensure_torch_tensor(rnn_states).to(self.device).float()

with timing.add_time("forward"):
policy_outputs = actor_critic(normalized_obs, rnn_states)
policy_outputs = actor_critic(normalized_obs, rnn_states, action_mask=action_mask)
policy_outputs["policy_version"] = torch.empty([num_samples]).fill_(self.param_client.policy_version)

with timing.add_time("prepare_outputs"):
2 changes: 1 addition & 1 deletion sample_factory/algo/sampling/non_batched_sampling.py
@@ -435,7 +435,7 @@ def _reset(self):

log.info("Decorrelating experience for %d frames...", decorrelate_steps)
for decorrelate_step in range(decorrelate_steps):
actions = [e.action_space.sample() for _ in range(self.num_agents)]
actions = [e.action_space.sample(obs.get("action_mask")) for obs in observations]
observations, rew, terminated, truncated, info = e.step(actions)

for agent_i, obs in enumerate(observations):
52 changes: 42 additions & 10 deletions sample_factory/algo/utils/action_distributions.py
@@ -42,7 +42,7 @@ def is_continuous_action_space(action_space: ActionSpace) -> bool:
return isinstance(action_space, gym.spaces.Box)


def get_action_distribution(action_space, raw_logits):
def get_action_distribution(action_space, raw_logits, action_mask=None):
"""
Create the distribution object based on provided action space and unprocessed logits.
:param action_space: Gym action space object
@@ -52,9 +52,9 @@ def get_action_distribution(action_space, raw_logits):
assert calc_num_action_parameters(action_space) == raw_logits.shape[-1]

if isinstance(action_space, gym.spaces.Discrete):
return CategoricalActionDistribution(raw_logits)
return CategoricalActionDistribution(raw_logits, action_mask)
elif isinstance(action_space, gym.spaces.Tuple):
return TupleActionDistribution(action_space, logits_flat=raw_logits)
return TupleActionDistribution(action_space, logits_flat=raw_logits, action_mask=action_mask)
elif isinstance(action_space, gym.spaces.Box):
return ContinuousActionDistribution(params=raw_logits)
else:
@@ -81,35 +81,65 @@ def argmax_actions(distribution):
raise NotImplementedError(f"Action distribution type {type(distribution)} does not support argmax!")


def masked_softmax(logits, mask):
# Mask out the invalid logits by adding a large negative number (-1e9)
logits = logits + (mask == 0) * -1e9
result = functional.softmax(logits, dim=-1)
result = result * mask
result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
return result


def masked_log_softmax(logits, mask):
logits = logits + (mask == 0) * -1e9
return functional.log_softmax(logits, dim=-1)


# noinspection PyAbstractClass
class CategoricalActionDistribution:
def __init__(self, raw_logits):
def __init__(self, raw_logits, action_mask=None):
"""
Ctor.
:param raw_logits: unprocessed logits, typically an output of a fully-connected layer
"""

self.raw_logits = raw_logits
self.action_mask = action_mask
self.log_p = self.p = None

@property
def probs(self):
if self.p is None:
self.p = functional.softmax(self.raw_logits, dim=-1)
if self.action_mask is not None:
self.p = masked_softmax(self.raw_logits, self.action_mask)
else:
self.p = functional.softmax(self.raw_logits, dim=-1)
return self.p

@property
def log_probs(self):
if self.log_p is None:
self.log_p = functional.log_softmax(self.raw_logits, dim=-1)
if self.action_mask is not None:
self.log_p = masked_log_softmax(self.raw_logits, self.action_mask)
else:
self.log_p = functional.log_softmax(self.raw_logits, dim=-1)
return self.log_p

def sample_gumbel(self):
sample = torch.argmax(self.raw_logits - torch.empty_like(self.raw_logits).exponential_().log_(), -1)
probs = self.raw_logits - torch.empty_like(self.raw_logits).exponential_().log_()
if self.action_mask is not None:
probs = probs * self.action_mask
sample = torch.argmax(probs, -1)
return sample

def sample(self):
samples = torch.multinomial(self.probs, 1, True)
probs = self.probs
if self.action_mask is not None:
all_zero = (probs.sum(dim=-1) == 0).unsqueeze(-1)
epsilons = torch.full_like(probs, 1e-6)
probs = torch.where(all_zero, epsilons, probs) # ensure sum of probabilities is non-zero

samples = torch.multinomial(probs, 1, True)
return samples

def log_prob(self, value):
@@ -181,16 +211,18 @@ class TupleActionDistribution:
"""

def __init__(self, action_space, logits_flat):
def __init__(self, action_space, logits_flat, action_mask=None):
self.logit_lengths = [calc_num_action_parameters(s) for s in action_space.spaces]
self.split_logits = torch.split(logits_flat, self.logit_lengths, dim=1)
self.action_lengths = [calc_num_actions(s) for s in action_space.spaces]
self.action_mask = action_mask

assert len(self.split_logits) == len(action_space.spaces)

self.distributions = []
for i, space in enumerate(action_space.spaces):
self.distributions.append(get_action_distribution(space, self.split_logits[i]))
action_mask = self.action_mask[i] if self.action_mask is not None else None
self.distributions.append(get_action_distribution(space, self.split_logits[i], action_mask))

@staticmethod
def _flatten_actions(list_of_action_batches):
4 changes: 4 additions & 0 deletions sample_factory/algo/utils/context.py
@@ -26,6 +26,10 @@ def set_global_context(ctx: SampleFactoryContext):


def reset_global_context():
"""
Most useful in tests, call this after any part of the global context has been modified
by a test in any way.
"""
global GLOBAL_CONTEXT
GLOBAL_CONTEXT = SampleFactoryContext()

4 changes: 3 additions & 1 deletion sample_factory/enjoy.py
@@ -136,6 +136,7 @@ def max_frames_reached(frames):
reward_list = []

obs, infos = env.reset()
action_mask = obs.pop("action_mask").to(device) if "action_mask" in obs else None
rnn_states = torch.zeros([env.num_agents, get_rnn_size(cfg)], dtype=torch.float32, device=device)
episode_reward = None
finished_episode = [False for _ in range(env.num_agents)]
@@ -149,7 +150,7 @@ def max_frames_reached(frames):

if not cfg.no_render:
visualize_policy_inputs(normalized_obs)
policy_outputs = actor_critic(normalized_obs, rnn_states)
policy_outputs = actor_critic(normalized_obs, rnn_states, action_mask=action_mask)

# sample actions from the distribution by default
actions = policy_outputs["actions"]
@@ -169,6 +170,7 @@ def max_frames_reached(frames):
last_render_start = render_frame(cfg, env, video_frames, num_episodes, last_render_start)

obs, rew, terminated, truncated, infos = env.step(actions)
action_mask = obs.pop("action_mask").to(device) if "action_mask" in obs else None
dones = make_dones(terminated, truncated)
infos = [{} for _ in range(env_info.num_agents)] if infos is None else infos

79 changes: 79 additions & 0 deletions sample_factory/envs/pettingzoo_envs.py
@@ -0,0 +1,79 @@
"""
Gym env wrappers for PettingZoo -> Gymnasium transition.
"""

import gymnasium as gym


class PettingZooParallelEnv(gym.Env):
def __init__(self, env):
if not all_equal([env.observation_space(a) for a in env.possible_agents]):
raise ValueError("All observation spaces must be equal")

if not all_equal([env.action_space(a) for a in env.possible_agents]):
raise ValueError("All action spaces must be equal")

self.env = env
self.metadata = env.metadata
self.render_mode = env.render_mode if hasattr(env, "render_mode") else env.unwrapped.render_mode
self.observation_space = normalize_observation_space(env.observation_space(env.possible_agents[0]))
self.action_space = env.action_space(env.possible_agents[0])
self.num_agents = env.max_num_agents
self.is_multiagent = True

def reset(self, **kwargs):
obs, infos = self.env.reset(**kwargs)
obs = [normalize_observation(obs.get(a)) for a in self.env.possible_agents]
infos = [infos[a] if a in infos else dict(is_active=False) for a in self.env.possible_agents]
return obs, infos

def step(self, actions):
actions = dict(zip(self.env.possible_agents, actions))
obs, rewards, terminations, truncations, infos = self.env.step(actions)

if not self.env.agents:
obs, infos = self.env.reset()

obs = [normalize_observation(obs.get(a)) for a in self.env.possible_agents]
rewards = [rewards.get(a) for a in self.env.possible_agents]
terminations = [terminations.get(a) for a in self.env.possible_agents]
truncations = [truncations.get(a) for a in self.env.possible_agents]
infos = [normalize_info(infos[a], a) if a in infos else dict(is_active=False) for a in self.env.possible_agents]
return obs, rewards, terminations, truncations, infos

def render(self):
return self.env.render()

def close(self):
self.env.close()


def all_equal(l_) -> bool:
return all(v == l_[0] for v in l_)


def normalize_observation_space(obs_space):
"""Normalize observation space with the key "obs" that's specially handled as the main value."""
if isinstance(obs_space, gym.spaces.Dict) and "observation" in obs_space.spaces:
spaces = dict(obs_space.spaces)
spaces["obs"] = spaces["observation"]
del spaces["observation"]
obs_space = gym.spaces.Dict(spaces)

return obs_space


def normalize_observation(obs):
if isinstance(obs, dict) and "observation" in obs:
obs["obs"] = obs["observation"]
del obs["observation"]

return obs


def normalize_info(info, agent):
"""active_agent is available when using `turn_based_aec_to_parallel` of PettingZoo."""
if isinstance(info, dict) and "active_agent" in info:
info["is_active"] = info["active_agent"] == agent

return info