Support PettingZoo Parallel API and action mask #305

Merged
4 changes: 2 additions & 2 deletions .github/workflows/test-ci.yml
@@ -31,7 +31,7 @@ jobs:
- name: Install conda env & dependencies
run: |
conda install python=${{ matrix.python-version }}
pip install -e '.[atari, mujoco, envpool]'
pip install -e '.[atari, mujoco, envpool, pettingzoo]'
conda list
- name: Install test dependencies
run: |
@@ -75,7 +75,7 @@ jobs:
- name: Install conda env & dependencies
run: |
conda install python=${{ matrix.python-version }}
pip install -e '.[atari, mujoco]'
pip install -e '.[atari, mujoco, pettingzoo]'
conda list
- name: Install test dependencies
run: |
5 changes: 4 additions & 1 deletion sample_factory/algo/sampling/inference_worker.py
@@ -321,11 +321,14 @@ def _handle_policy_steps(self, timing):
if actor_critic.training:
actor_critic.eval() # need to call this because we can be in serial mode

action_mask = (
ensure_torch_tensor(obs.pop("action_mask")).to(self.device) if "action_mask" in obs else None
)
normalized_obs = prepare_and_normalize_obs(actor_critic, obs)
rnn_states = ensure_torch_tensor(rnn_states).to(self.device).float()

with timing.add_time("forward"):
policy_outputs = actor_critic(normalized_obs, rnn_states)
policy_outputs = actor_critic(normalized_obs, rnn_states, action_mask=action_mask)
policy_outputs["policy_version"] = torch.empty([num_samples]).fill_(self.param_client.policy_version)

with timing.add_time("prepare_outputs"):
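For context on the change above: the inference worker expects the environment to deliver the mask as part of the observation dict and simply pops it before normalization. A hypothetical observation from a Tic-Tac-Toe-style env might look like the sketch below; the key names match this PR's wrapper, the shapes are purely illustrative.

import numpy as np

# Hypothetical dict observation: "obs" is the board encoding, "action_mask"
# marks which of the 9 cells are still legal (1 = legal, 0 = illegal).
obs = {
    "obs": np.zeros((3, 3, 2), dtype=np.int8),
    "action_mask": np.array([1, 1, 0, 1, 0, 1, 1, 1, 1], dtype=np.int8),
}

# The worker pops "action_mask" first, so only the remaining keys ("obs" here)
# are passed through prepare_and_normalize_obs and reach the encoder.
mask = obs.pop("action_mask") if "action_mask" in obs else None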
2 changes: 1 addition & 1 deletion sample_factory/algo/sampling/non_batched_sampling.py
@@ -435,7 +435,7 @@ def _reset(self):

log.info("Decorrelating experience for %d frames...", decorrelate_steps)
for decorrelate_step in range(decorrelate_steps):
actions = [e.action_space.sample() for _ in range(self.num_agents)]
actions = [e.action_space.sample(obs.get("action_mask")) for obs in observations]
observations, rew, terminated, truncated, info = e.step(actions)

for agent_i, obs in enumerate(observations):
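One note on the decorrelation change above: Gymnasium's Discrete.sample accepts an optional int8 mask, which is what the per-agent call relies on. A minimal, environment-independent sketch:

import numpy as np
import gymnasium as gym

space = gym.spaces.Discrete(4)
mask = np.array([1, 0, 1, 0], dtype=np.int8)  # 1 = action allowed, 0 = disallowed
action = space.sample(mask)  # draws only from {0, 2}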
57 changes: 41 additions & 16 deletions sample_factory/algo/utils/action_distributions.py
@@ -61,20 +61,26 @@ def get_action_distribution(action_space, raw_logits):
raise NotImplementedError(f"Action space type {type(action_space)} not supported!")


def sample_actions_log_probs(distribution):
def sample_actions_log_probs(distribution, action_mask=None):
if isinstance(distribution, TupleActionDistribution):
return distribution.sample_actions_log_probs()
return distribution.sample_actions_log_probs(action_mask)
else:
actions = distribution.sample()
if isinstance(distribution, ContinuousActionDistribution):
actions = distribution.sample()
else:
actions = distribution.sample(action_mask)
log_prob_actions = distribution.log_prob(actions)
return actions, log_prob_actions


def argmax_actions(distribution):
def argmax_actions(distribution, action_mask=None):
if isinstance(distribution, TupleActionDistribution):
return distribution.argmax()
return distribution.argmax(action_mask)
elif hasattr(distribution, "probs"):
return torch.argmax(distribution.probs, dim=-1)
probs = distribution.probs
if action_mask is not None:
probs = probs * action_mask
return torch.argmax(probs, dim=-1)
elif hasattr(distribution, "means"):
return distribution.means
else:
@@ -104,12 +110,22 @@ def log_probs(self):
self.log_p = functional.log_softmax(self.raw_logits, dim=-1)
return self.log_p

def sample_gumbel(self):
sample = torch.argmax(self.raw_logits - torch.empty_like(self.raw_logits).exponential_().log_(), -1)
def sample_gumbel(self, action_mask=None):
probs = self.raw_logits - torch.empty_like(self.raw_logits).exponential_().log_()
if action_mask is not None:
probs = probs * action_mask
sample = torch.argmax(probs, -1)
return sample

def sample(self):
samples = torch.multinomial(self.probs, 1, True)
def sample(self, action_mask=None):
probs = self.probs
if action_mask is not None:
probs = probs * action_mask
all_zero = (probs.sum(dim=-1) == 0).unsqueeze(-1)
epsilons = torch.full_like(probs, 1e-6)
probs = torch.where(all_zero, epsilons, probs) # ensure sum of probabilities is non-zero

samples = torch.multinomial(probs, 1, True)
Owner:

Just checking: don't we have to re-normalize the probabilities here so they add up to 1? Does torch.multinomial do this internally?

Contributor Author:

Technically, it seems they don't need to add up to 1, according to the doc: "The rows of input do not need to sum to one (in which case we use the values as weights), ..."

https://pytorch.org/docs/stable/generated/torch.multinomial.html

But I'm honestly not so sure (I'm a newbie in RL), so please feel free to fix it if you see something wrong with the code.

Contributor Author:

Actually, it seems the values do need to be normalized with softmax as far as I understand, so I implemented that.

return samples

def log_prob(self, value):
@@ -209,18 +225,27 @@ def _calc_log_probs(self, list_of_action_batches):

return log_probs

def sample_actions_log_probs(self):
list_of_action_batches = [d.sample() for d in self.distributions]
def sample_actions_log_probs(self, action_mask=None):
action_mask = [action_mask[i] if action_mask is not None else None for i in range(len(self.distributions))]
list_of_action_batches = [
d.sample() if isinstance(d, ContinuousActionDistribution) else d.sample(action_mask[i])
for i, d in enumerate(self.distributions)
]
batch_of_action_tuples = self._flatten_actions(list_of_action_batches)
log_probs = self._calc_log_probs(list_of_action_batches)
return batch_of_action_tuples, log_probs

def sample(self):
list_of_action_batches = [d.sample() for d in self.distributions]
def sample(self, action_mask=None):
action_mask = [action_mask[i] if action_mask is not None else None for i in range(len(self.distributions))]
list_of_action_batches = [
d.sample() if isinstance(d, ContinuousActionDistribution) else d.sample(action_mask[i])
for i, d in enumerate(self.distributions)
]
return self._flatten_actions(list_of_action_batches)

def argmax(self):
list_of_action_batches = [argmax_actions(d) for d in self.distributions]
def argmax(self, action_mask=None):
action_mask = [action_mask[i] if action_mask is not None else None for i in range(len(self.distributions))]
list_of_action_batches = [argmax_actions(d, action_mask[i]) for i, d in enumerate(self.distributions)]
return torch.cat(list_of_action_batches).unsqueeze(0)

def log_prob(self, actions):
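To summarize the torch.multinomial discussion above in runnable form, here is a standalone sketch (not the PR's exact code) of masked categorical sampling: zero out invalid actions, guard against rows where everything is masked, renormalize, then sample.

import torch


def sample_masked_categorical(probs: torch.Tensor, action_mask: torch.Tensor) -> torch.Tensor:
    # Zero out the probabilities of invalid actions.
    masked = probs * action_mask
    # If every action in a row is masked out, fall back to a tiny uniform weight
    # so torch.multinomial still has something to draw from.
    all_zero = masked.sum(dim=-1, keepdim=True) == 0
    masked = torch.where(all_zero, torch.full_like(masked, 1e-6), masked)
    # Renormalize so each row sums to 1; torch.multinomial only requires
    # non-negative weights, but explicit normalization keeps the semantics clear.
    masked = masked / masked.sum(dim=-1, keepdim=True)
    return torch.multinomial(masked, num_samples=1, replacement=True)


probs = torch.tensor([[0.2, 0.5, 0.3]])
mask = torch.tensor([[1.0, 0.0, 1.0]])
print(sample_masked_categorical(probs, mask))  # prints index 0 or 2, never 1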
6 changes: 4 additions & 2 deletions sample_factory/enjoy.py
@@ -136,6 +136,7 @@ def max_frames_reached(frames):
reward_list = []

obs, infos = env.reset()
action_mask = obs.pop("action_mask").to(device) if "action_mask" in obs else None
rnn_states = torch.zeros([env.num_agents, get_rnn_size(cfg)], dtype=torch.float32, device=device)
episode_reward = None
finished_episode = [False for _ in range(env.num_agents)]
@@ -149,14 +150,14 @@

if not cfg.no_render:
visualize_policy_inputs(normalized_obs)
policy_outputs = actor_critic(normalized_obs, rnn_states)
policy_outputs = actor_critic(normalized_obs, rnn_states, action_mask=action_mask)

# sample actions from the distribution by default
actions = policy_outputs["actions"]

if cfg.eval_deterministic:
action_distribution = actor_critic.action_distribution()
actions = argmax_actions(action_distribution)
actions = argmax_actions(action_distribution, action_mask)

# actions shape should be [num_agents, num_actions] even if it's [1, 1]
if actions.ndim == 1:
@@ -169,6 +170,7 @@
last_render_start = render_frame(cfg, env, video_frames, num_episodes, last_render_start)

obs, rew, terminated, truncated, infos = env.step(actions)
action_mask = obs.pop("action_mask").to(device) if "action_mask" in obs else None
dones = make_dones(terminated, truncated)
infos = [{} for _ in range(env_info.num_agents)] if infos is None else infos

79 changes: 79 additions & 0 deletions sample_factory/envs/pettingzoo_envs.py
@@ -0,0 +1,79 @@
"""
Wrappers that adapt PettingZoo Parallel environments to Sample Factory's Gymnasium-based multi-agent env interface.
"""

import gymnasium as gym


class PettingZooParallelEnv(gym.Env):
def __init__(self, env):
if not all_equal([env.observation_space(a) for a in env.possible_agents]):
raise ValueError("All observation spaces must be equal")

if not all_equal([env.action_space(a) for a in env.possible_agents]):
raise ValueError("All action spaces must be equal")

self.env = env
self.metadata = env.metadata
self.render_mode = env.render_mode if hasattr(env, "render_mode") else env.unwrapped.render_mode
self.observation_space = normalize_observation_space(env.observation_space(env.possible_agents[0]))
self.action_space = env.action_space(env.possible_agents[0])
self.num_agents = env.max_num_agents
self.is_multiagent = True

def reset(self, **kwargs):
obs, infos = self.env.reset(**kwargs)
obs = [normalize_observation(obs.get(a)) for a in self.env.possible_agents]
infos = [infos[a] if a in infos else dict(is_active=False) for a in self.env.possible_agents]
return obs, infos

def step(self, actions):
actions = dict(zip(self.env.possible_agents, actions))
obs, rewards, terminations, truncations, infos = self.env.step(actions)

if not self.env.agents:
obs, infos = self.env.reset()

obs = [normalize_observation(obs.get(a)) for a in self.env.possible_agents]
rewards = [rewards.get(a) for a in self.env.possible_agents]
terminations = [terminations.get(a) for a in self.env.possible_agents]
truncations = [truncations.get(a) for a in self.env.possible_agents]
infos = [normalize_info(infos[a], a) if a in infos else dict(is_active=False) for a in self.env.possible_agents]
return obs, rewards, terminations, truncations, infos

def render(self):
return self.env.render()

def close(self):
self.env.close()


def all_equal(l_) -> bool:
return all(v == l_[0] for v in l_)


def normalize_observation_space(obs_space):
"""Normalize observation space with the key "obs" that's specially handled as the main value."""
if isinstance(obs_space, gym.spaces.Dict) and "observation" in obs_space.spaces:
spaces = dict(obs_space.spaces)
spaces["obs"] = spaces["observation"]
del spaces["observation"]
obs_space = gym.spaces.Dict(spaces)

return obs_space


def normalize_observation(obs):
if isinstance(obs, dict) and "observation" in obs:
obs["obs"] = obs["observation"]
del obs["observation"]

return obs


def normalize_info(info, agent):
"""active_agent is available when using `turn_based_aec_to_parallel` of PettingZoo."""
if isinstance(info, dict) and "active_agent" in info:
info["is_active"] = info["active_agent"] == agent

return info
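A minimal usage sketch for the wrapper above, assuming a turn-based PettingZoo classic env converted with turn_based_aec_to_parallel. The env choice and the registration call are illustrative (they follow Sample Factory's usual register_env convention) and are not part of this diff:

from pettingzoo.classic import tictactoe_v3
from pettingzoo.utils.conversions import turn_based_aec_to_parallel

from sample_factory.envs.env_utils import register_env
from sample_factory.envs.pettingzoo_envs import PettingZooParallelEnv


def make_tictactoe_env(full_env_name, cfg=None, env_config=None, render_mode=None):
    # tictactoe_v3 is an AEC env; convert it to the Parallel API before wrapping.
    aec_env = tictactoe_v3.env(render_mode=render_mode)
    return PettingZooParallelEnv(turn_based_aec_to_parallel(aec_env))


register_env("tictactoe_v3", make_tictactoe_env)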
50 changes: 37 additions & 13 deletions sample_factory/model/actor_critic.py
@@ -2,6 +2,7 @@

from typing import Dict, Optional

import gymnasium as gym
import torch
from torch import Tensor, nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
@@ -108,10 +109,12 @@ def summaries(self) -> Dict:
def action_distribution(self):
return self.last_action_distribution

def _maybe_sample_actions(self, sample_actions: bool, result: TensorDict) -> None:
def _maybe_sample_actions(
self, sample_actions: bool, result: TensorDict, action_mask: Optional[Tensor] = None
) -> None:
if sample_actions:
# for non-trivial action spaces it is faster to do these together
actions, result["log_prob_actions"] = sample_actions_log_probs(self.last_action_distribution)
actions, result["log_prob_actions"] = sample_actions_log_probs(self.last_action_distribution, action_mask)
assert actions.dim() == 2 # TODO: remove this once we test everything
result["actions"] = actions.squeeze(dim=1)

@@ -121,10 +124,14 @@ def forward_head(self, normalized_obs_dict: Dict[str, Tensor]) -> Tensor:
def forward_core(self, head_output, rnn_states):
raise NotImplementedError()

def forward_tail(self, core_output, values_only: bool, sample_actions: bool) -> TensorDict:
def forward_tail(
self, core_output, values_only: bool, sample_actions: bool, action_mask: Optional[Tensor] = None
) -> TensorDict:
raise NotImplementedError()

def forward(self, normalized_obs_dict, rnn_states, values_only: bool = False) -> TensorDict:
def forward(
self, normalized_obs_dict, rnn_states, values_only: bool = False, action_mask: Optional[Tensor] = None
) -> TensorDict:
raise NotImplementedError()


@@ -160,7 +167,9 @@ def forward_core(self, head_output: Tensor, rnn_states):
x, new_rnn_states = self.core(head_output, rnn_states)
return x, new_rnn_states

def forward_tail(self, core_output, values_only: bool, sample_actions: bool) -> TensorDict:
def forward_tail(
self, core_output, values_only: bool, sample_actions: bool, action_mask: Optional[Tensor] = None
) -> TensorDict:
decoder_output = self.decoder(core_output)
values = self.critic_linear(decoder_output).squeeze()

@@ -173,13 +182,15 @@ def forward_tail(self, core_output, values_only: bool, sample_actions: bool) ->
# `action_logits` is not the best name here, better would be "action distribution parameters"
result["action_logits"] = action_distribution_params

self._maybe_sample_actions(sample_actions, result)
self._maybe_sample_actions(sample_actions, result, action_mask)
return result

def forward(self, normalized_obs_dict, rnn_states, values_only=False) -> TensorDict:
def forward(
self, normalized_obs_dict, rnn_states, values_only=False, action_mask: Optional[Tensor] = None
) -> TensorDict:
x = self.forward_head(normalized_obs_dict)
x, new_rnn_states = self.forward_core(x, rnn_states)
result = self.forward_tail(x, values_only, sample_actions=True)
result = self.forward_tail(x, values_only, sample_actions=True, action_mask=action_mask)
result["new_rnn_states"] = new_rnn_states
return result

@@ -276,7 +287,9 @@ def forward_head(self, normalized_obs_dict: Dict):
def forward_core(self, head_output, rnn_states):
return self.core_func(head_output, rnn_states)

def forward_tail(self, core_output, values_only: bool, sample_actions: bool) -> TensorDict:
def forward_tail(
self, core_output, values_only: bool, sample_actions: bool, action_mask: Optional[Tensor] = None
) -> TensorDict:
core_outputs = core_output.chunk(len(self.cores), dim=1)

# second core output corresponds to the critic
@@ -294,13 +307,15 @@ def forward_tail(self, core_output, values_only: bool, sample_actions: bool) ->

result["action_logits"] = action_distribution_params

self._maybe_sample_actions(sample_actions, result)
self._maybe_sample_actions(sample_actions, result, action_mask)
return result

def forward(self, normalized_obs_dict, rnn_states, values_only=False) -> TensorDict:
def forward(
self, normalized_obs_dict, rnn_states, values_only=False, action_mask: Optional[Tensor] = None
) -> TensorDict:
x = self.forward_head(normalized_obs_dict)
x, new_rnn_states = self.forward_core(x, rnn_states)
result = self.forward_tail(x, values_only, sample_actions=True)
result = self.forward_tail(x, values_only, sample_actions=True, action_mask=action_mask)
result["new_rnn_states"] = new_rnn_states
return result

@@ -321,4 +336,13 @@ def create_actor_critic(cfg: Config, obs_space: ObsSpace, action_space: ActionSp
from sample_factory.algo.utils.context import global_model_factory

make_actor_critic_func = global_model_factory().make_actor_critic_func
return make_actor_critic_func(cfg, obs_space, action_space)
return make_actor_critic_func(cfg, obs_space_without_action_mask(obs_space), action_space)
Owner:

I think it would be a bit cleaner to add the special treatment for action_mask inside make_actor_critic_func, but I'm fine with this solution too 👍

Contributor Author:

I assumed that meant doing it inside default_make_actor_critic_func. Fixed it that way anyway 🙏



def obs_space_without_action_mask(obs_space: ObsSpace) -> ObsSpace:
if isinstance(obs_space, gym.spaces.Dict) and "action_mask" in obs_space.spaces:
spaces = obs_space.spaces.copy()
del spaces["action_mask"]
obs_space = gym.spaces.Dict(spaces)

return obs_space
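A hedged illustration of the helper above (the Box shapes are made up; the helper itself is the one added to actor_critic.py in this diff):

import gymnasium as gym
import numpy as np

from sample_factory.model.actor_critic import obs_space_without_action_mask

obs_space = gym.spaces.Dict(
    {
        "obs": gym.spaces.Box(low=0, high=1, shape=(3, 3, 2), dtype=np.int8),
        "action_mask": gym.spaces.Box(low=0, high=1, shape=(9,), dtype=np.int8),
    }
)

filtered = obs_space_without_action_mask(obs_space)
assert "action_mask" not in filtered.spaces  # the model is built without the mask
assert "obs" in filtered.spaces              # everything else is preserved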
5 changes: 4 additions & 1 deletion setup.py
@@ -27,6 +27,7 @@
"debugpy ~= 1.6",
]
_envpool_deps = ["envpool"]
_pettingzoo_deps = ["pettingzoo[classic]"]

_docs_deps = [
"mkdocs-material",
@@ -80,11 +81,13 @@ def is_macos():
"dev": ["black", "isort>=5.12", "pytest<8.0", "flake8", "pre-commit", "twine"]
+ _docs_deps
+ _atari_deps
+ _mujoco_deps,
+ _mujoco_deps
+ _pettingzoo_deps,
"atari": _atari_deps,
"envpool": _envpool_deps,
"mujoco": _mujoco_deps,
"nethack": _nethack_deps,
"pettingzoo": _pettingzoo_deps,
"vizdoom": ["vizdoom<2.0", "gymnasium[classic_control]"],
# "dmlab": ["dm_env"], <-- these are just auxiliary packages, the main package has to be built from sources
},
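With the new extra in place, the PettingZoo dependencies can be pulled in together with the package, e.g. pip install -e '.[pettingzoo]'; the CI workflows above do exactly this, combined with the atari and mujoco extras.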