From abbc4591fcfa3f2cb20b98bb9b0f2a1ee83f47fa Mon Sep 17 00:00:00 2001 From: Naoyuki Kanezawa Date: Wed, 23 Oct 2024 14:14:25 +0700 Subject: [PATCH] Support PettingZoo Parallel API and action mask (#305) * support PettingZoo Parallel API and action mask * Fix pre-commit * Add pettingzoo to github test deps * add missing reset_global_context() * fix ContinuousActionDistribution to receive action_mask but ignore the value for now * exclude action_mask in default_make_actor_critic_func * add docs for action mask and pettingzoo env * improve doc * apply action_mask to epsilons * use original probs when all actions are masked * apply softmax with mask * remove unnecessary computations * fix to apply mask to probs * improve masked_softmax and masked_log_softmax for extreme cases --------- Co-authored-by: Aleksei Petrenko --- .github/workflows/test-ci.yml | 4 +- Makefile | 2 +- docs/07-advanced-topics/action-masking.md | 38 +++++++ .../09-environment-integrations/pettingzoo.md | 46 ++++++++ mkdocs.yml | 2 + .../algo/sampling/inference_worker.py | 5 +- .../algo/sampling/non_batched_sampling.py | 2 +- .../algo/utils/action_distributions.py | 52 +++++++-- sample_factory/algo/utils/context.py | 4 + sample_factory/enjoy.py | 4 +- sample_factory/envs/pettingzoo_envs.py | 79 ++++++++++++++ .../model/action_parameterization.py | 12 ++- sample_factory/model/actor_critic.py | 47 ++++++-- sample_factory/model/model_factory.py | 1 - setup.py | 5 +- sf_examples/enjoy_pettingzoo_env.py | 16 +++ sf_examples/train_pettingzoo_env.py | 100 ++++++++++++++++++ tests/algo/test_action_distributions.py | 27 ++++- tests/envs/pettingzoo/__init__.py | 0 tests/envs/pettingzoo/test_pettingzoo.py | 64 +++++++++++ tests/envs/utils.py | 26 +++-- 21 files changed, 493 insertions(+), 43 deletions(-) create mode 100644 docs/07-advanced-topics/action-masking.md create mode 100644 docs/09-environment-integrations/pettingzoo.md create mode 100644 sample_factory/envs/pettingzoo_envs.py create mode 100644 sf_examples/enjoy_pettingzoo_env.py create mode 100644 sf_examples/train_pettingzoo_env.py create mode 100644 tests/envs/pettingzoo/__init__.py create mode 100644 tests/envs/pettingzoo/test_pettingzoo.py diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 1519e50e9..48842bffe 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -31,7 +31,7 @@ jobs: - name: Install conda env & dependencies run: | conda install python=${{ matrix.python-version }} - pip install -e '.[atari, mujoco, envpool]' + pip install -e '.[atari, mujoco, envpool, pettingzoo]' conda list - name: Install test dependencies run: | @@ -75,7 +75,7 @@ jobs: - name: Install conda env & dependencies run: | conda install python=${{ matrix.python-version }} - pip install -e '.[atari, mujoco]' + pip install -e '.[atari, mujoco, pettingzoo]' conda list - name: Install test dependencies run: | diff --git a/Makefile b/Makefile index 5fe3d1996..b7b2cf847 100644 --- a/Makefile +++ b/Makefile @@ -46,7 +46,7 @@ check-codestyle: .PHONY: test test: - pytest -s --maxfail=2 + pytest -s --maxfail=2 -rA # ; echo "Tests finished. You might need to type 'reset' and press Enter to fix the terminal window" diff --git a/docs/07-advanced-topics/action-masking.md b/docs/07-advanced-topics/action-masking.md new file mode 100644 index 000000000..26a972eeb --- /dev/null +++ b/docs/07-advanced-topics/action-masking.md @@ -0,0 +1,38 @@ +# Action Masking + +Action masking is a technique used to restrict the set of actions available to an agent in certain states. This can be particularly useful in environments where some actions are invalid or undesirable in specific situations. See [paper](https://arxiv.org/abs/2006.14171) for more details. + +## Implementing Action Masking + +To implement action masking in your environment, you need to add an `action_mask` field to the observation dictionary returned by your environment. Here's how to do it: + +1. Define the action mask space in your environment's observation space +2. Generate and return the action mask in both `reset()` and `step()` methods + +Here's an example of a custom environment implementing action masking: + +```python +import gymnasium as gym +import numpy as np + +class CustomEnv(gym.Env): + def __init__(self, full_env_name, cfg, render_mode=None): + ... + self.observation_space = gym.spaces.Dict({ + "obs": gym.spaces.Box(low=0, high=1, shape=(3, 3, 2), dtype=np.int8), + "action_mask": gym.spaces.Box(low=0, high=1, shape=(9,), dtype=np.int8), + }) + self.action_space = gym.spaces.Discrete(9) + + def reset(self, **kwargs): + ... + # Initial action mask that allows all actions + action_mask = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1]) + return {"obs": obs, "action_mask": action_mask}, info + + def step(self, action): + ... + # Generate new action mask based on the current state + action_mask = np.array([1, 0, 0, 1, 1, 1, 0, 1, 1]) + return {"obs": obs, "action_mask": action_mask}, reward, terminated, truncated, info +``` diff --git a/docs/09-environment-integrations/pettingzoo.md b/docs/09-environment-integrations/pettingzoo.md new file mode 100644 index 000000000..0bdb533f2 --- /dev/null +++ b/docs/09-environment-integrations/pettingzoo.md @@ -0,0 +1,46 @@ +# PettingZoo + +[PettingZoo](https://pettingzoo.farama.org/) is a Python library for conducting research in multi-agent reinforcement learning. This guide explains how to use PettingZoo environments with Sample Factory. + +## Installation + +Install Sample Factory with PettingZoo dependencies with PyPI: + +```bash +pip install -e sample-factory[pettingzoo] +``` + +## Running Experiments + +Run PettingZoo experiments with the scripts in `sf_examples`. +The default parameters are not tuned for throughput. + +To train a model in the `tictactoe_v3` environment: + +``` +python -m sf_examples.train_pettingzoo_env --algo=APPO --env=tictactoe_v3 --experiment="Experiment Name" +``` + +To visualize the training results, use the `enjoy_pettingzoo_env` script: + +``` +python -m sf_examples.enjoy_pettingzoo_env --env=tictactoe_v3 --experiment="Experiment Name" +``` + +Currently, the scripts in `sf_examples` are set up for the `tictactoe_v3` environment. To use other PettingZoo environments, you'll need to modify the scripts or add your own as explained below. + +### Adding a new PettingZoo environment + +To add a new PettingZoo environment, follow the instructions from [Custom environments](../03-customization/custom-environments.md), with the additional step of wrapping your PettingZoo environment with `sample_factory.envs.pettingzoo_envs.PettingZooParallelEnv`. + +Here's an example of how to create a factory function for a PettingZoo environment: + +```python +from sample_factory.envs.pettingzoo_envs import PettingZooParallelEnv +import some_pettingzoo_env # Import your desired PettingZoo environment + +def make_pettingzoo_env(full_env_name, cfg=None, env_config=None, render_mode=None): + return PettingZooParallelEnv(some_pettingzoo_env.parallel_env(render_mode=render_mode)) +``` + +Note: Sample Factory supports only the [Parallel API](https://pettingzoo.farama.org/api/parallel/) of PettingZoo. If your environment uses the AEC API, you can convert it to Parallel API using `pettingzoo.utils.conversions.aec_to_parallel` or `pettingzoo.utils.conversions.turn_based_aec_to_parallel`. Be aware that these conversions have some limitations. For more details, refer to the [PettingZoo documentation](https://pettingzoo.farama.org/api/wrappers/pz_wrappers/). \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index a6a807203..b9e4544af 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -156,6 +156,7 @@ nav: - 07-advanced-topics/passing-info.md - 07-advanced-topics/observer.md - 07-advanced-topics/profiling.md + - 07-advanced-topics/action-masking.md - Miscellaneous: - 08-miscellaneous/tests.md - 08-miscellaneous/v1-to-v2.md @@ -170,6 +171,7 @@ nav: - 09-environment-integrations/nethack.md - 09-environment-integrations/brax.md - 09-environment-integrations/swarm-rl.md + - 09-environment-integrations/pettingzoo.md - Huggingface Integration: - 10-huggingface/huggingface.md - Release Notes: diff --git a/sample_factory/algo/sampling/inference_worker.py b/sample_factory/algo/sampling/inference_worker.py index 1fede875b..3a122fa84 100644 --- a/sample_factory/algo/sampling/inference_worker.py +++ b/sample_factory/algo/sampling/inference_worker.py @@ -321,11 +321,14 @@ def _handle_policy_steps(self, timing): if actor_critic.training: actor_critic.eval() # need to call this because we can be in serial mode + action_mask = ( + ensure_torch_tensor(obs.pop("action_mask")).to(self.device) if "action_mask" in obs else None + ) normalized_obs = prepare_and_normalize_obs(actor_critic, obs) rnn_states = ensure_torch_tensor(rnn_states).to(self.device).float() with timing.add_time("forward"): - policy_outputs = actor_critic(normalized_obs, rnn_states) + policy_outputs = actor_critic(normalized_obs, rnn_states, action_mask=action_mask) policy_outputs["policy_version"] = torch.empty([num_samples]).fill_(self.param_client.policy_version) with timing.add_time("prepare_outputs"): diff --git a/sample_factory/algo/sampling/non_batched_sampling.py b/sample_factory/algo/sampling/non_batched_sampling.py index ebda139ca..54b26b85f 100644 --- a/sample_factory/algo/sampling/non_batched_sampling.py +++ b/sample_factory/algo/sampling/non_batched_sampling.py @@ -435,7 +435,7 @@ def _reset(self): log.info("Decorrelating experience for %d frames...", decorrelate_steps) for decorrelate_step in range(decorrelate_steps): - actions = [e.action_space.sample() for _ in range(self.num_agents)] + actions = [e.action_space.sample(obs.get("action_mask")) for obs in observations] observations, rew, terminated, truncated, info = e.step(actions) for agent_i, obs in enumerate(observations): diff --git a/sample_factory/algo/utils/action_distributions.py b/sample_factory/algo/utils/action_distributions.py index 083b8504e..b279aeb3a 100644 --- a/sample_factory/algo/utils/action_distributions.py +++ b/sample_factory/algo/utils/action_distributions.py @@ -42,7 +42,7 @@ def is_continuous_action_space(action_space: ActionSpace) -> bool: return isinstance(action_space, gym.spaces.Box) -def get_action_distribution(action_space, raw_logits): +def get_action_distribution(action_space, raw_logits, action_mask=None): """ Create the distribution object based on provided action space and unprocessed logits. :param action_space: Gym action space object @@ -52,9 +52,9 @@ def get_action_distribution(action_space, raw_logits): assert calc_num_action_parameters(action_space) == raw_logits.shape[-1] if isinstance(action_space, gym.spaces.Discrete): - return CategoricalActionDistribution(raw_logits) + return CategoricalActionDistribution(raw_logits, action_mask) elif isinstance(action_space, gym.spaces.Tuple): - return TupleActionDistribution(action_space, logits_flat=raw_logits) + return TupleActionDistribution(action_space, logits_flat=raw_logits, action_mask=action_mask) elif isinstance(action_space, gym.spaces.Box): return ContinuousActionDistribution(params=raw_logits) else: @@ -81,35 +81,65 @@ def argmax_actions(distribution): raise NotImplementedError(f"Action distribution type {type(distribution)} does not support argmax!") +def masked_softmax(logits, mask): + # Mask out the invalid logits by adding a large negative number (-1e9) + logits = logits + (mask == 0) * -1e9 + result = functional.softmax(logits, dim=-1) + result = result * mask + result = result / (result.sum(dim=-1, keepdim=True) + 1e-13) + return result + + +def masked_log_softmax(logits, mask): + logits = logits + (mask == 0) * -1e9 + return functional.log_softmax(logits, dim=-1) + + # noinspection PyAbstractClass class CategoricalActionDistribution: - def __init__(self, raw_logits): + def __init__(self, raw_logits, action_mask=None): """ Ctor. :param raw_logits: unprocessed logits, typically an output of a fully-connected layer """ self.raw_logits = raw_logits + self.action_mask = action_mask self.log_p = self.p = None @property def probs(self): if self.p is None: - self.p = functional.softmax(self.raw_logits, dim=-1) + if self.action_mask is not None: + self.p = masked_softmax(self.raw_logits, self.action_mask) + else: + self.p = functional.softmax(self.raw_logits, dim=-1) return self.p @property def log_probs(self): if self.log_p is None: - self.log_p = functional.log_softmax(self.raw_logits, dim=-1) + if self.action_mask is not None: + self.log_p = masked_log_softmax(self.raw_logits, self.action_mask) + else: + self.log_p = functional.log_softmax(self.raw_logits, dim=-1) return self.log_p def sample_gumbel(self): - sample = torch.argmax(self.raw_logits - torch.empty_like(self.raw_logits).exponential_().log_(), -1) + probs = self.raw_logits - torch.empty_like(self.raw_logits).exponential_().log_() + if self.action_mask is not None: + probs = probs * self.action_mask + sample = torch.argmax(probs, -1) return sample def sample(self): - samples = torch.multinomial(self.probs, 1, True) + probs = self.probs + if self.action_mask is not None: + all_zero = (probs.sum(dim=-1) == 0).unsqueeze(-1) + epsilons = torch.full_like(probs, 1e-6) + probs = torch.where(all_zero, epsilons, probs) # ensure sum of probabilities is non-zero + + samples = torch.multinomial(probs, 1, True) return samples def log_prob(self, value): @@ -181,16 +211,18 @@ class TupleActionDistribution: """ - def __init__(self, action_space, logits_flat): + def __init__(self, action_space, logits_flat, action_mask=None): self.logit_lengths = [calc_num_action_parameters(s) for s in action_space.spaces] self.split_logits = torch.split(logits_flat, self.logit_lengths, dim=1) self.action_lengths = [calc_num_actions(s) for s in action_space.spaces] + self.action_mask = action_mask assert len(self.split_logits) == len(action_space.spaces) self.distributions = [] for i, space in enumerate(action_space.spaces): - self.distributions.append(get_action_distribution(space, self.split_logits[i])) + action_mask = self.action_mask[i] if self.action_mask is not None else None + self.distributions.append(get_action_distribution(space, self.split_logits[i], action_mask)) @staticmethod def _flatten_actions(list_of_action_batches): diff --git a/sample_factory/algo/utils/context.py b/sample_factory/algo/utils/context.py index 8eb8f22a8..50426d199 100644 --- a/sample_factory/algo/utils/context.py +++ b/sample_factory/algo/utils/context.py @@ -26,6 +26,10 @@ def set_global_context(ctx: SampleFactoryContext): def reset_global_context(): + """ + Most useful in tests, call this after any part of the global context has been modified + by a test in any way. + """ global GLOBAL_CONTEXT GLOBAL_CONTEXT = SampleFactoryContext() diff --git a/sample_factory/enjoy.py b/sample_factory/enjoy.py index 1bd5e475e..9d2af9b25 100644 --- a/sample_factory/enjoy.py +++ b/sample_factory/enjoy.py @@ -136,6 +136,7 @@ def max_frames_reached(frames): reward_list = [] obs, infos = env.reset() + action_mask = obs.pop("action_mask").to(device) if "action_mask" in obs else None rnn_states = torch.zeros([env.num_agents, get_rnn_size(cfg)], dtype=torch.float32, device=device) episode_reward = None finished_episode = [False for _ in range(env.num_agents)] @@ -149,7 +150,7 @@ def max_frames_reached(frames): if not cfg.no_render: visualize_policy_inputs(normalized_obs) - policy_outputs = actor_critic(normalized_obs, rnn_states) + policy_outputs = actor_critic(normalized_obs, rnn_states, action_mask=action_mask) # sample actions from the distribution by default actions = policy_outputs["actions"] @@ -169,6 +170,7 @@ def max_frames_reached(frames): last_render_start = render_frame(cfg, env, video_frames, num_episodes, last_render_start) obs, rew, terminated, truncated, infos = env.step(actions) + action_mask = obs.pop("action_mask").to(device) if "action_mask" in obs else None dones = make_dones(terminated, truncated) infos = [{} for _ in range(env_info.num_agents)] if infos is None else infos diff --git a/sample_factory/envs/pettingzoo_envs.py b/sample_factory/envs/pettingzoo_envs.py new file mode 100644 index 000000000..2571b40fc --- /dev/null +++ b/sample_factory/envs/pettingzoo_envs.py @@ -0,0 +1,79 @@ +""" +Gym env wrappers for PettingZoo -> Gymnasium transition. +""" + +import gymnasium as gym + + +class PettingZooParallelEnv(gym.Env): + def __init__(self, env): + if not all_equal([env.observation_space(a) for a in env.possible_agents]): + raise ValueError("All observation spaces must be equal") + + if not all_equal([env.action_space(a) for a in env.possible_agents]): + raise ValueError("All action spaces must be equal") + + self.env = env + self.metadata = env.metadata + self.render_mode = env.render_mode if hasattr(env, "render_mode") else env.unwrapped.render_mode + self.observation_space = normalize_observation_space(env.observation_space(env.possible_agents[0])) + self.action_space = env.action_space(env.possible_agents[0]) + self.num_agents = env.max_num_agents + self.is_multiagent = True + + def reset(self, **kwargs): + obs, infos = self.env.reset(**kwargs) + obs = [normalize_observation(obs.get(a)) for a in self.env.possible_agents] + infos = [infos[a] if a in infos else dict(is_active=False) for a in self.env.possible_agents] + return obs, infos + + def step(self, actions): + actions = dict(zip(self.env.possible_agents, actions)) + obs, rewards, terminations, truncations, infos = self.env.step(actions) + + if not self.env.agents: + obs, infos = self.env.reset() + + obs = [normalize_observation(obs.get(a)) for a in self.env.possible_agents] + rewards = [rewards.get(a) for a in self.env.possible_agents] + terminations = [terminations.get(a) for a in self.env.possible_agents] + truncations = [truncations.get(a) for a in self.env.possible_agents] + infos = [normalize_info(infos[a], a) if a in infos else dict(is_active=False) for a in self.env.possible_agents] + return obs, rewards, terminations, truncations, infos + + def render(self): + return self.env.render() + + def close(self): + self.env.close() + + +def all_equal(l_) -> bool: + return all(v == l_[0] for v in l_) + + +def normalize_observation_space(obs_space): + """Normalize observation space with the key "obs" that's specially handled as the main value.""" + if isinstance(obs_space, gym.spaces.Dict) and "observation" in obs_space.spaces: + spaces = dict(obs_space.spaces) + spaces["obs"] = spaces["observation"] + del spaces["observation"] + obs_space = gym.spaces.Dict(spaces) + + return obs_space + + +def normalize_observation(obs): + if isinstance(obs, dict) and "observation" in obs: + obs["obs"] = obs["observation"] + del obs["observation"] + + return obs + + +def normalize_info(info, agent): + """active_agent is available when using `turn_based_aec_to_parallel` of PettingZoo.""" + if isinstance(info, dict) and "active_agent" in info: + info["is_active"] = info["active_agent"] == agent + + return info diff --git a/sample_factory/model/action_parameterization.py b/sample_factory/model/action_parameterization.py index c11cf7c62..130289830 100644 --- a/sample_factory/model/action_parameterization.py +++ b/sample_factory/model/action_parameterization.py @@ -30,10 +30,12 @@ def __init__(self, cfg, core_out_size, action_space): num_action_outputs = calc_num_action_parameters(action_space) self.distribution_linear = nn.Linear(core_out_size, num_action_outputs) - def forward(self, actor_core_output): + def forward(self, actor_core_output, action_mask=None): """Just forward the FC layer and generate the distribution object.""" action_distribution_params = self.distribution_linear(actor_core_output) - action_distribution = get_action_distribution(self.action_space, raw_logits=action_distribution_params) + action_distribution = get_action_distribution( + self.action_space, raw_logits=action_distribution_params, action_mask=action_mask + ) return action_distribution_params, action_distribution @@ -58,7 +60,7 @@ def __init__(self, cfg, core_out_size, action_space): initial_stddev.fill_(math.log(self.cfg.initial_stddev)) self.learned_stddev = nn.Parameter(initial_stddev, requires_grad=True) - def forward(self, actor_core_output: Tensor): + def forward(self, actor_core_output: Tensor, action_mask=None): action_means = self.distribution_linear(actor_core_output) if self.tanh_scale > 0: # scale the action means to be in the range [-tanh_scale, tanh_scale] @@ -68,5 +70,7 @@ def forward(self, actor_core_output: Tensor): batch_size = action_means.shape[0] action_stddevs = self.learned_stddev.repeat(batch_size, 1) action_distribution_params = torch.cat((action_means, action_stddevs), dim=1) - action_distribution = get_action_distribution(self.action_space, raw_logits=action_distribution_params) + action_distribution = get_action_distribution( + self.action_space, raw_logits=action_distribution_params, action_mask=action_mask + ) return action_distribution_params, action_distribution diff --git a/sample_factory/model/actor_critic.py b/sample_factory/model/actor_critic.py index 39926055d..e1fe81719 100644 --- a/sample_factory/model/actor_critic.py +++ b/sample_factory/model/actor_critic.py @@ -2,6 +2,7 @@ from typing import Dict, Optional +import gymnasium as gym import torch from torch import Tensor, nn from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence @@ -121,10 +122,14 @@ def forward_head(self, normalized_obs_dict: Dict[str, Tensor]) -> Tensor: def forward_core(self, head_output, rnn_states): raise NotImplementedError() - def forward_tail(self, core_output, values_only: bool, sample_actions: bool) -> TensorDict: + def forward_tail( + self, core_output, values_only: bool, sample_actions: bool, action_mask: Optional[Tensor] = None + ) -> TensorDict: raise NotImplementedError() - def forward(self, normalized_obs_dict, rnn_states, values_only: bool = False) -> TensorDict: + def forward( + self, normalized_obs_dict, rnn_states, values_only: bool = False, action_mask: Optional[Tensor] = None + ) -> TensorDict: raise NotImplementedError() @@ -160,7 +165,9 @@ def forward_core(self, head_output: Tensor, rnn_states): x, new_rnn_states = self.core(head_output, rnn_states) return x, new_rnn_states - def forward_tail(self, core_output, values_only: bool, sample_actions: bool) -> TensorDict: + def forward_tail( + self, core_output, values_only: bool, sample_actions: bool, action_mask: Optional[Tensor] = None + ) -> TensorDict: decoder_output = self.decoder(core_output) values = self.critic_linear(decoder_output).squeeze() @@ -168,7 +175,9 @@ def forward_tail(self, core_output, values_only: bool, sample_actions: bool) -> if values_only: return result - action_distribution_params, self.last_action_distribution = self.action_parameterization(decoder_output) + action_distribution_params, self.last_action_distribution = self.action_parameterization( + decoder_output, action_mask + ) # `action_logits` is not the best name here, better would be "action distribution parameters" result["action_logits"] = action_distribution_params @@ -176,10 +185,12 @@ def forward_tail(self, core_output, values_only: bool, sample_actions: bool) -> self._maybe_sample_actions(sample_actions, result) return result - def forward(self, normalized_obs_dict, rnn_states, values_only=False) -> TensorDict: + def forward( + self, normalized_obs_dict, rnn_states, values_only=False, action_mask: Optional[Tensor] = None + ) -> TensorDict: x = self.forward_head(normalized_obs_dict) x, new_rnn_states = self.forward_core(x, rnn_states) - result = self.forward_tail(x, values_only, sample_actions=True) + result = self.forward_tail(x, values_only, sample_actions=True, action_mask=action_mask) result["new_rnn_states"] = new_rnn_states return result @@ -276,7 +287,9 @@ def forward_head(self, normalized_obs_dict: Dict): def forward_core(self, head_output, rnn_states): return self.core_func(head_output, rnn_states) - def forward_tail(self, core_output, values_only: bool, sample_actions: bool) -> TensorDict: + def forward_tail( + self, core_output, values_only: bool, sample_actions: bool, action_mask: Optional[Tensor] = None + ) -> TensorDict: core_outputs = core_output.chunk(len(self.cores), dim=1) # second core output corresponds to the critic @@ -290,17 +303,21 @@ def forward_tail(self, core_output, values_only: bool, sample_actions: bool) -> # first core output corresponds to the actor actor_decoder_output = self.actor_decoder(core_outputs[0]) - action_distribution_params, self.last_action_distribution = self.action_parameterization(actor_decoder_output) + action_distribution_params, self.last_action_distribution = self.action_parameterization( + actor_decoder_output, action_mask + ) result["action_logits"] = action_distribution_params self._maybe_sample_actions(sample_actions, result) return result - def forward(self, normalized_obs_dict, rnn_states, values_only=False) -> TensorDict: + def forward( + self, normalized_obs_dict, rnn_states, values_only=False, action_mask: Optional[Tensor] = None + ) -> TensorDict: x = self.forward_head(normalized_obs_dict) x, new_rnn_states = self.forward_core(x, rnn_states) - result = self.forward_tail(x, values_only, sample_actions=True) + result = self.forward_tail(x, values_only, sample_actions=True, action_mask=action_mask) result["new_rnn_states"] = new_rnn_states return result @@ -309,6 +326,7 @@ def default_make_actor_critic_func(cfg: Config, obs_space: ObsSpace, action_spac from sample_factory.algo.utils.context import global_model_factory model_factory = global_model_factory() + obs_space = obs_space_without_action_mask(obs_space) if cfg.actor_critic_share_weights: return ActorCriticSharedWeights(model_factory, obs_space, action_space, cfg) @@ -322,3 +340,12 @@ def create_actor_critic(cfg: Config, obs_space: ObsSpace, action_space: ActionSp make_actor_critic_func = global_model_factory().make_actor_critic_func return make_actor_critic_func(cfg, obs_space, action_space) + + +def obs_space_without_action_mask(obs_space: ObsSpace) -> ObsSpace: + if isinstance(obs_space, gym.spaces.Dict) and "action_mask" in obs_space.spaces: + spaces = obs_space.spaces.copy() + del spaces["action_mask"] + obs_space = gym.spaces.Dict(spaces) + + return obs_space diff --git a/sample_factory/model/model_factory.py b/sample_factory/model/model_factory.py index bad21dd0f..8ad823584 100644 --- a/sample_factory/model/model_factory.py +++ b/sample_factory/model/model_factory.py @@ -19,7 +19,6 @@ def __init__(self): Optional custom functions for creating parts of the model (encoders, decoders, etc.), or even overriding the entire actor-critic with a custom model. """ - self.make_actor_critic_func: MakeActorCriticFunc = default_make_actor_critic_func # callables user can specify to generate parts of the policy diff --git a/setup.py b/setup.py index 6097520ee..20fa95b1c 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ "debugpy ~= 1.6", ] _envpool_deps = ["envpool"] +_pettingzoo_deps = ["pettingzoo[classic]"] _docs_deps = [ "mkdocs-material", @@ -80,11 +81,13 @@ def is_macos(): "dev": ["black", "isort>=5.12", "pytest<8.0", "flake8", "pre-commit", "twine"] + _docs_deps + _atari_deps - + _mujoco_deps, + + _mujoco_deps + + _pettingzoo_deps, "atari": _atari_deps, "envpool": _envpool_deps, "mujoco": _mujoco_deps, "nethack": _nethack_deps, + "pettingzoo": _pettingzoo_deps, "vizdoom": ["vizdoom<2.0", "gymnasium[classic_control]"], # "dmlab": ["dm_env"], <-- these are just auxiliary packages, the main package has to be built from sources }, diff --git a/sf_examples/enjoy_pettingzoo_env.py b/sf_examples/enjoy_pettingzoo_env.py new file mode 100644 index 000000000..24408cea0 --- /dev/null +++ b/sf_examples/enjoy_pettingzoo_env.py @@ -0,0 +1,16 @@ +import sys + +from sample_factory.enjoy import enjoy +from sf_examples.train_pettingzoo_env import parse_custom_args, register_custom_components + + +def main(): # pragma: no cover + """Script entry point.""" + register_custom_components() + cfg = parse_custom_args(evaluation=True) + status = enjoy(cfg) + return status + + +if __name__ == "__main__": # pragma: no cover + sys.exit(main()) diff --git a/sf_examples/train_pettingzoo_env.py b/sf_examples/train_pettingzoo_env.py new file mode 100644 index 000000000..cc5a402d4 --- /dev/null +++ b/sf_examples/train_pettingzoo_env.py @@ -0,0 +1,100 @@ +""" +An example that shows how to use SampleFactory with a PettingZoo env. + +Example command line for tictactoe_v3: +python -m sf_examples.train_pettingzoo_env --algo=APPO --use_rnn=False --num_envs_per_worker=20 --policy_workers_per_policy=2 --recurrence=1 --with_vtrace=False --batch_size=512 --save_every_sec=10 --experiment_summaries_interval=10 --experiment=example_pettingzoo_tictactoe_v3 --env=tictactoe_v3 +python -m sf_examples.enjoy_pettingzoo_env --algo=APPO --experiment=example_pettingzoo_tictactoe_v3 --env=tictactoe_v3 + +""" + +import sys +from typing import List, Optional + +import gymnasium as gym +import torch +from pettingzoo.classic import tictactoe_v3 +from pettingzoo.utils import turn_based_aec_to_parallel +from torch import Tensor, nn + +from sample_factory.algo.utils.context import global_model_factory +from sample_factory.algo.utils.torch_utils import calc_num_elements +from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args +from sample_factory.envs.env_utils import register_env +from sample_factory.envs.pettingzoo_envs import PettingZooParallelEnv +from sample_factory.model.encoder import Encoder +from sample_factory.model.model_utils import create_mlp, nonlinearity +from sample_factory.train import run_rl +from sample_factory.utils.attr_dict import AttrDict +from sample_factory.utils.typing import Config, ObsSpace + + +class CustomConvEncoder(Encoder): + def __init__(self, cfg: Config, obs_space: ObsSpace): + super().__init__(cfg) + main_obs_space = obs_space["obs"] + input_channels = main_obs_space.shape[0] + conv_filters = [[input_channels, 32, 2, 1], [32, 64, 2, 1], [64, 128, 2, 1]] + activation = nonlinearity(self.cfg) + extra_mlp_layers = cfg.encoder_conv_mlp_layers + enc = ConvEncoderImpl(main_obs_space.shape, conv_filters, extra_mlp_layers, activation) + self.enc = torch.jit.script(enc) + self.encoder_out_size = calc_num_elements(self.enc, main_obs_space.shape) + + def get_out_size(self): + return self.encoder_out_size + + def forward(self, obs_dict): + main_obs = obs_dict["obs"] + return self.enc(main_obs) + + +class ConvEncoderImpl(nn.Module): + def __init__(self, obs_shape: AttrDict, conv_filters: List, extra_mlp_layers: List[int], activation: nn.Module): + super().__init__() + conv_layers = [] + + for layer in conv_filters: + inp_ch, out_ch, filter_size, padding = layer + conv_layers.append(nn.Conv2d(inp_ch, out_ch, filter_size, padding=padding)) + conv_layers.append(activation) + + self.conv_head = nn.Sequential(*conv_layers) + self.conv_head_out_size = calc_num_elements(self.conv_head, obs_shape) + self.mlp_layers = create_mlp(extra_mlp_layers, self.conv_head_out_size, activation) + + def forward(self, obs: Tensor) -> Tensor: + x = self.conv_head(obs) + x = x.contiguous().view(-1, self.conv_head_out_size) + x = self.mlp_layers(x) + return x + + +def make_pettingzoo_env(full_env_name, cfg=None, env_config=None, render_mode: Optional[str] = None): + return PettingZooParallelEnv(turn_based_aec_to_parallel(tictactoe_v3.env(render_mode=render_mode))) + + +def make_custom_encoder(cfg: Config, obs_space: ObsSpace) -> Encoder: + return CustomConvEncoder(cfg, obs_space) + + +def register_custom_components(): + register_env("tictactoe_v3", make_pettingzoo_env) + global_model_factory().register_encoder_factory(make_custom_encoder) + + +def parse_custom_args(argv=None, evaluation=False): + parser, cfg = parse_sf_args(argv=argv, evaluation=evaluation) + cfg = parse_full_cfg(parser, argv) + return cfg + + +def main(): + """Script entry point.""" + register_custom_components() + cfg = parse_custom_args() + status = run_rl(cfg) + return status + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/algo/test_action_distributions.py b/tests/algo/test_action_distributions.py index 54834f750..3d1a4ac11 100644 --- a/tests/algo/test_action_distributions.py +++ b/tests/algo/test_action_distributions.py @@ -19,17 +19,19 @@ class TestActionDistributions: @pytest.mark.parametrize("gym_space", [gym.spaces.Discrete(3)]) @pytest.mark.parametrize("batch_size", [128]) - def test_simple_distribution(self, gym_space, batch_size): + @pytest.mark.parametrize("has_action_mask", [False, True]) + def test_simple_distribution(self, gym_space, batch_size, has_action_mask): simple_action_space = gym_space simple_num_logits = calc_num_action_parameters(simple_action_space) assert simple_num_logits == simple_action_space.n + expected_actions, action_mask = generate_expected_actions(simple_action_space.n, batch_size, has_action_mask) simple_logits = torch.rand(batch_size, simple_num_logits) - simple_action_distribution = get_action_distribution(simple_action_space, simple_logits) + simple_action_distribution = get_action_distribution(simple_action_space, simple_logits, action_mask) simple_actions = simple_action_distribution.sample() assert list(simple_actions.shape) == [batch_size, 1] - assert all(0 <= a < simple_action_space.n for a in simple_actions) + assert all(torch.isin(a, expected_actions) for a in simple_actions) @pytest.mark.parametrize("gym_space", [gym.spaces.Discrete(3)]) @pytest.mark.parametrize("batch_size", [128]) @@ -91,7 +93,8 @@ def test_gumbel_trick(self, gym_space, batch_size, device_type): @pytest.mark.parametrize("num_spaces", [1, 4]) @pytest.mark.parametrize("gym_space", [gym.spaces.Discrete(1), gym.spaces.Discrete(3)]) @pytest.mark.parametrize("batch_size", [128]) - def test_tuple_distribution(self, num_spaces, gym_space, batch_size): + @pytest.mark.parametrize("has_action_mask", [False, True]) + def test_tuple_distribution(self, num_spaces, gym_space, batch_size, has_action_mask): spaces = [gym_space for _ in range(num_spaces)] action_space = gym.spaces.Tuple(spaces) @@ -100,10 +103,13 @@ def test_tuple_distribution(self, num_spaces, gym_space, batch_size): assert num_logits == sum(s.n for s in action_space.spaces) - action_distribution = get_action_distribution(action_space, logits) + expected_actions, action_mask = generate_expected_actions(gym_space.n, batch_size, has_action_mask) + action_mask = action_mask.repeat(num_spaces, 1) if action_mask is not None else None + action_distribution = get_action_distribution(action_space, logits, action_mask) tuple_actions = action_distribution.sample() assert list(tuple_actions.shape) == [batch_size, num_spaces] + assert all(torch.isin(a, expected_actions) for actions in tuple_actions for a in actions) log_probs = action_distribution.log_prob(tuple_actions) assert list(log_probs.shape) == [batch_size] @@ -218,3 +224,14 @@ def test_tuple_action_distribution(spaces, sizes): assert actions.size() == (BATCH_SIZE, num_actions) assert action_log_probs.size() == (BATCH_SIZE,) + + +def generate_expected_actions(action_space_size, batch_size, has_action_mask): + if has_action_mask: + expected_actions = torch.tensor([i for i in range(action_space_size) if i % 2 == 0]) + action_mask = torch.tensor([[1 if i in expected_actions else 0 for i in range(action_space_size)]] * batch_size) + else: + expected_actions = torch.tensor(range(action_space_size)) + action_mask = None + + return expected_actions, action_mask diff --git a/tests/envs/pettingzoo/__init__.py b/tests/envs/pettingzoo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/envs/pettingzoo/test_pettingzoo.py b/tests/envs/pettingzoo/test_pettingzoo.py new file mode 100644 index 000000000..9f8ff805a --- /dev/null +++ b/tests/envs/pettingzoo/test_pettingzoo.py @@ -0,0 +1,64 @@ +import shutil +from os.path import isdir + +import pytest + +from sample_factory.algo.utils.context import reset_global_context +from sample_factory.algo.utils.misc import ExperimentStatus +from sample_factory.train import run_rl +from sample_factory.utils.utils import log +from sf_examples.train_pettingzoo_env import make_pettingzoo_env, parse_custom_args, register_custom_components +from tests.envs.utils import eval_env_performance +from tests.utils import clean_test_dir + + +class TestPettingZooEnv: + @pytest.fixture(scope="class", autouse=True) + def register_pettingzoo_fixture(self): + register_custom_components() + yield # this is where the actual test happens + reset_global_context() + + # noinspection PyUnusedLocal + @staticmethod + def make_env(env_config): + return make_pettingzoo_env("tictactoe_v3", cfg=parse_custom_args(argv=["--algo=APPO", "--env=tictactoe_v3"])) + + def test_pettingzoo_performance(self): + eval_env_performance(self.make_env, "pettingzoo") + + @staticmethod + def _run_test_env( + env: str = "tictactoe_v3", + num_workers: int = 2, + train_steps: int = 512, + batched_sampling: bool = False, + serial_mode: bool = True, + async_rl: bool = False, + batch_size: int = 256, + ): + log.debug(f"Testing with parameters {locals()}...") + assert train_steps > batch_size, "We need sufficient number of steps to accumulate at least one batch" + + experiment_name = "test_" + env + + cfg = parse_custom_args(argv=["--algo=APPO", f"--env={env}", f"--experiment={experiment_name}"]) + cfg.serial_mode = serial_mode + cfg.async_rl = async_rl + cfg.batched_sampling = batched_sampling + cfg.num_workers = num_workers + cfg.train_for_env_steps = train_steps + cfg.batch_size = batch_size + cfg.decorrelate_envs_on_one_worker = False + cfg.seed = 0 + cfg.device = "cpu" + + directory = clean_test_dir(cfg) + status = run_rl(cfg) + assert status == ExperimentStatus.SUCCESS + assert isdir(directory) + shutil.rmtree(directory, ignore_errors=True) + + @pytest.mark.parametrize("batched_sampling", [False, True]) + def test_basic_envs(self, batched_sampling): + self._run_test_env(batched_sampling=batched_sampling) diff --git a/tests/envs/utils.py b/tests/envs/utils.py index 9b36f8148..78c321938 100644 --- a/tests/envs/utils.py +++ b/tests/envs/utils.py @@ -11,6 +11,8 @@ def eval_env_performance(make_env, env_type, verbose=False, eval_frames=10_000): with t.timeit("init"): env = make_env(AttrDict({"worker_index": 0, "vector_index": 0})) total_num_frames, frames = eval_frames, 0 + num_agents = env.num_agents if hasattr(env, "num_agents") else 1 + is_multiagent = env.is_multiagent if hasattr(env, "is_multiagent") else num_agents > 1 with t.timeit("first_reset"): env.reset() @@ -22,7 +24,7 @@ def eval_env_performance(make_env, env_type, verbose=False, eval_frames=10_000): done = False start_reset = time.time() - env.reset() + obs, info = env.reset() t.reset += time.time() - start_reset num_resets += 1 @@ -33,13 +35,25 @@ def eval_env_performance(make_env, env_type, verbose=False, eval_frames=10_000): env.render() time.sleep(1.0 / 40) - obs, rew, terminated, truncated, info = env.step(env.action_space.sample()) - done = terminated | truncated - if verbose: - log.info("Received reward %.3f", rew) + if is_multiagent: + action_mask = [o.get("action_mask") if isinstance(o, dict) else None for o in obs] + action = [env.action_space.sample(m) for m in action_mask] + else: + action_mask = obs.get("action_mask") if isinstance(obs, dict) else None + action = env.action_space.sample(action_mask) + + obs, rew, terminated, truncated, info = env.step(action) + + if is_multiagent: + done = all(a | b for a, b in zip(terminated, truncated)) + else: + done = terminated | truncated + info = [info] + if verbose: + log.info("Received reward %.3f", rew) t.step += time.time() - start_step - frames += num_env_steps([info]) + frames += num_env_steps(info) fps = total_num_frames / t.experience log.debug("%s performance:", env_type)