support PettingZoo Parallel API and action mask
nkzawa committed Sep 15, 2024
1 parent 91d4322 · commit d776e31
Showing 13 changed files with 390 additions and 45 deletions.
5 changes: 4 additions & 1 deletion sample_factory/algo/sampling/inference_worker.py
@@ -321,11 +321,14 @@ def _handle_policy_steps(self, timing):
if actor_critic.training:
actor_critic.eval() # need to call this because we can be in serial mode

action_mask = (
ensure_torch_tensor(obs.pop("action_mask")).to(self.device) if "action_mask" in obs else None
)
normalized_obs = prepare_and_normalize_obs(actor_critic, obs)
rnn_states = ensure_torch_tensor(rnn_states).to(self.device).float()

with timing.add_time("forward"):
policy_outputs = actor_critic(normalized_obs, rnn_states)
policy_outputs = actor_critic(normalized_obs, rnn_states, action_mask=action_mask)
policy_outputs["policy_version"] = torch.empty([num_samples]).fill_(self.param_client.policy_version)

with timing.add_time("prepare_outputs"):
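For context on where the mask comes from: PettingZoo's classic environments return dict observations whose "action_mask" entry marks the currently legal actions, and the wrapper added later in this commit passes that key through to the worker. A minimal sketch of the pop-and-forward pattern used above, with made-up shapes (the 3x3x2 board planes and the 9-way mask are illustrative, not taken from any specific env):

import torch

obs = {
    "obs": torch.zeros(2, 3, 3, 2),  # batch of 2 agents, illustrative board planes
    "action_mask": torch.tensor(
        [[1, 0, 1, 1, 0, 1, 1, 1, 0], [0, 1, 0, 0, 1, 0, 0, 0, 1]], dtype=torch.float32
    ),
}

# Same idea as the inference worker change: strip the mask before obs normalization,
# keep it around to pass to the actor-critic forward pass.
action_mask = obs.pop("action_mask") if "action_mask" in obs else None
assert "action_mask" not in obs and action_mask is not None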
2 changes: 1 addition & 1 deletion sample_factory/algo/sampling/non_batched_sampling.py
@@ -435,7 +435,7 @@ def _reset(self):

log.info("Decorrelating experience for %d frames...", decorrelate_steps)
for decorrelate_step in range(decorrelate_steps):
actions = [e.action_space.sample() for _ in range(self.num_agents)]
actions = [e.action_space.sample(obs.get("action_mask")) for obs in observations]
observations, rew, terminated, truncated, info = e.step(actions)

for agent_i, obs in enumerate(observations):
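The decorrelation step above relies on Gymnasium's masked sampling: for a Discrete space, space.sample(mask) takes an int8 array with one entry per action and only draws indices where the mask is 1. A quick sketch (the space size and mask values are illustrative, and it assumes the env supplies the mask as int8, as the PettingZoo classic envs do):

import numpy as np
from gymnasium.spaces import Discrete

space = Discrete(9)
mask = np.array([1, 0, 1, 1, 0, 1, 1, 1, 0], dtype=np.int8)  # 1 = legal action

action = space.sample(mask)  # never returns an index where mask == 0
assert mask[action] == 1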
57 changes: 41 additions & 16 deletions sample_factory/algo/utils/action_distributions.py
@@ -61,20 +61,26 @@ def get_action_distribution(action_space, raw_logits):
raise NotImplementedError(f"Action space type {type(action_space)} not supported!")


def sample_actions_log_probs(distribution):
def sample_actions_log_probs(distribution, action_mask=None):
if isinstance(distribution, TupleActionDistribution):
return distribution.sample_actions_log_probs()
return distribution.sample_actions_log_probs(action_mask)
else:
actions = distribution.sample()
if isinstance(distribution, ContinuousActionDistribution):
actions = distribution.sample()
else:
actions = distribution.sample(action_mask)
log_prob_actions = distribution.log_prob(actions)
return actions, log_prob_actions


def argmax_actions(distribution):
def argmax_actions(distribution, action_mask=None):
if isinstance(distribution, TupleActionDistribution):
return distribution.argmax()
return distribution.argmax(action_mask)
elif hasattr(distribution, "probs"):
return torch.argmax(distribution.probs, dim=-1)
probs = distribution.probs
if action_mask is not None:
probs = probs * action_mask
return torch.argmax(probs, dim=-1)
elif hasattr(distribution, "means"):
return distribution.means
else:
@@ -104,12 +110,22 @@ def log_probs(self):
self.log_p = functional.log_softmax(self.raw_logits, dim=-1)
return self.log_p

def sample_gumbel(self):
sample = torch.argmax(self.raw_logits - torch.empty_like(self.raw_logits).exponential_().log_(), -1)
def sample_gumbel(self, action_mask=None):
probs = self.raw_logits - torch.empty_like(self.raw_logits).exponential_().log_()
if action_mask is not None:
probs = probs * action_mask
sample = torch.argmax(probs, -1)
return sample

def sample(self):
samples = torch.multinomial(self.probs, 1, True)
def sample(self, action_mask=None):
probs = self.probs
if action_mask is not None:
probs = probs * action_mask
all_zero = (probs.sum(dim=-1) == 0).unsqueeze(-1)
epsilons = torch.full_like(probs, 1e-6)
probs = torch.where(all_zero, epsilons, probs) # ensure sum of probabilities is non-zero

samples = torch.multinomial(probs, 1, True)
return samples

def log_prob(self, value):
@@ -209,18 +225,27 @@ def _calc_log_probs(self, list_of_action_batches):

return log_probs

def sample_actions_log_probs(self):
list_of_action_batches = [d.sample() for d in self.distributions]
def sample_actions_log_probs(self, action_mask=None):
action_mask = [action_mask[i] if action_mask is not None else None for i in range(len(self.distributions))]
list_of_action_batches = [
d.sample() if isinstance(d, ContinuousActionDistribution) else d.sample(action_mask[i])
for i, d in enumerate(self.distributions)
]
batch_of_action_tuples = self._flatten_actions(list_of_action_batches)
log_probs = self._calc_log_probs(list_of_action_batches)
return batch_of_action_tuples, log_probs

def sample(self):
list_of_action_batches = [d.sample() for d in self.distributions]
def sample(self, action_mask=None):
action_mask = [action_mask[i] if action_mask is not None else None for i in range(len(self.distributions))]
list_of_action_batches = [
d.sample() if isinstance(d, ContinuousActionDistribution) else d.sample(action_mask[i])
for i, d in enumerate(self.distributions)
]
return self._flatten_actions(list_of_action_batches)

def argmax(self):
list_of_action_batches = [argmax_actions(d) for d in self.distributions]
def argmax(self, action_mask=None):
action_mask = [action_mask[i] if action_mask is not None else None for i in range(len(self.distributions))]
list_of_action_batches = [argmax_actions(d, action_mask[i]) for i, d in enumerate(self.distributions)]
return torch.cat(list_of_action_batches).unsqueeze(0)

def log_prob(self, actions):
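The categorical sampling change above zeroes out masked probabilities and substitutes a tiny uniform epsilon when a row masks out every action, so torch.multinomial never sees an all-zero row; argmax_actions applies the same probs * mask trick before taking the argmax. A standalone sketch of that behaviour (not the library code itself, just the same masking logic on toy numbers):

import torch

probs = torch.tensor([[0.2, 0.5, 0.3],
                      [0.6, 0.3, 0.1]])
mask = torch.tensor([[1.0, 0.0, 1.0],
                     [0.0, 0.0, 0.0]])  # second row: degenerate case, nothing legal

masked = probs * mask
all_zero = (masked.sum(dim=-1) == 0).unsqueeze(-1)
masked = torch.where(all_zero, torch.full_like(masked, 1e-6), masked)

samples = torch.multinomial(masked, num_samples=1, replacement=True)
# Row 0 can only yield index 0 or 2; row 1 falls back to a uniform random choice.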
6 changes: 4 additions & 2 deletions sample_factory/enjoy.py
@@ -136,6 +136,7 @@ def max_frames_reached(frames):
reward_list = []

obs, infos = env.reset()
action_mask = obs.pop("action_mask").to(device) if "action_mask" in obs else None
rnn_states = torch.zeros([env.num_agents, get_rnn_size(cfg)], dtype=torch.float32, device=device)
episode_reward = None
finished_episode = [False for _ in range(env.num_agents)]
@@ -149,14 +150,14 @@ def max_frames_reached(frames):

if not cfg.no_render:
visualize_policy_inputs(normalized_obs)
policy_outputs = actor_critic(normalized_obs, rnn_states)
policy_outputs = actor_critic(normalized_obs, rnn_states, action_mask=action_mask)

# sample actions from the distribution by default
actions = policy_outputs["actions"]

if cfg.eval_deterministic:
action_distribution = actor_critic.action_distribution()
actions = argmax_actions(action_distribution)
actions = argmax_actions(action_distribution, action_mask)

# actions shape should be [num_agents, num_actions] even if it's [1, 1]
if actions.ndim == 1:
@@ -169,6 +170,7 @@ def max_frames_reached(frames):
last_render_start = render_frame(cfg, env, video_frames, num_episodes, last_render_start)

obs, rew, terminated, truncated, infos = env.step(actions)
action_mask = obs.pop("action_mask").to(device) if "action_mask" in obs else None
dones = make_dones(terminated, truncated)
infos = [{} for _ in range(env_info.num_agents)] if infos is None else infos

79 changes: 79 additions & 0 deletions sample_factory/envs/pettingzoo_envs.py
@@ -0,0 +1,79 @@
"""
Gymnasium-compatible wrappers for PettingZoo parallel environments.
"""

import gymnasium as gym


class PettingZooParallelEnv(gym.Env):
def __init__(self, env):
if not all_equal([env.observation_space(a) for a in env.possible_agents]):
raise ValueError("All observation spaces must be equal")

if not all_equal([env.action_space(a) for a in env.possible_agents]):
raise ValueError("All action spaces must be equal")

self.env = env
self.metadata = env.metadata
self.render_mode = env.render_mode if hasattr(env, "render_mode") else env.unwrapped.render_mode
self.observation_space = normalize_observation_space(env.observation_space(env.possible_agents[0]))
self.action_space = env.action_space(env.possible_agents[0])
self.num_agents = env.max_num_agents
self.is_multiagent = True

def reset(self, **kwargs):
obs, infos = self.env.reset(**kwargs)
obs = [normalize_observation(obs.get(a)) for a in self.env.possible_agents]
infos = [infos[a] if a in infos else dict(is_active=False) for a in self.env.possible_agents]
return obs, infos

def step(self, actions):
actions = dict(zip(self.env.possible_agents, actions))
obs, rewards, terminations, truncations, infos = self.env.step(actions)

if not self.env.agents:
obs, infos = self.env.reset()

obs = [normalize_observation(obs.get(a)) for a in self.env.possible_agents]
rewards = [rewards.get(a) for a in self.env.possible_agents]
terminations = [terminations.get(a) for a in self.env.possible_agents]
truncations = [truncations.get(a) for a in self.env.possible_agents]
infos = [normalize_info(infos[a], a) if a in infos else dict(is_active=False) for a in self.env.possible_agents]
return obs, rewards, terminations, truncations, infos

def render(self):
return self.env.render()

def close(self):
self.env.close()


def all_equal(l):
return all(v == l[0] for v in l)


def normalize_observation_space(obs_space):
"""Normalize observation space with the key "obs" that's specially handled as the main value."""
if isinstance(obs_space, gym.spaces.Dict) and "observation" in obs_space.spaces:
spaces = dict(obs_space.spaces)
spaces["obs"] = spaces["observation"]
del spaces["observation"]
obs_space = gym.spaces.Dict(spaces)

return obs_space


def normalize_observation(obs):
if isinstance(obs, dict) and "observation" in obs:
obs["obs"] = obs["observation"]
del obs["observation"]

return obs


def normalize_info(info, agent):
"""active_agent is available when using `turn_based_aec_to_parallel` of PettingZoo."""
if isinstance(info, dict) and "active_agent" in info:
info["is_active"] = info["active_agent"] == agent

return info
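PettingZooParallelEnv presents a PettingZoo parallel env through Sample Factory's multi-agent interface: observations, rewards, terminations and infos become lists ordered by possible_agents, and the env auto-resets once every agent is done. A hedged sketch of how such an env might be registered for training; the tictactoe_v3 env, the registered name, and the register_env / make_env_func pattern are assumptions based on Sample Factory's usual conventions, not part of this diff (the commit's own example lives in sf_examples/train_pettingzoo_env.py, which is not shown here):

from pettingzoo.classic import tictactoe_v3  # any turn-based classic env would do
from pettingzoo.utils.conversions import turn_based_aec_to_parallel

from sample_factory.envs.env_utils import register_env
from sample_factory.envs.pettingzoo_envs import PettingZooParallelEnv


def make_tictactoe_env(full_env_name, cfg=None, env_config=None, render_mode=None):
    aec_env = tictactoe_v3.env(render_mode=render_mode)
    parallel_env = turn_based_aec_to_parallel(aec_env)  # sets info["active_agent"]
    return PettingZooParallelEnv(parallel_env)


register_env("tictactoe_v3", make_tictactoe_env)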
50 changes: 37 additions & 13 deletions sample_factory/model/actor_critic.py
@@ -2,6 +2,7 @@

from typing import Dict, Optional

import gymnasium as gym
import torch
from torch import Tensor, nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
@@ -108,10 +109,12 @@ def summaries(self) -> Dict:
def action_distribution(self):
return self.last_action_distribution

def _maybe_sample_actions(self, sample_actions: bool, result: TensorDict) -> None:
def _maybe_sample_actions(
self, sample_actions: bool, result: TensorDict, action_mask: Optional[Tensor] = None
) -> None:
if sample_actions:
# for non-trivial action spaces it is faster to do these together
actions, result["log_prob_actions"] = sample_actions_log_probs(self.last_action_distribution)
actions, result["log_prob_actions"] = sample_actions_log_probs(self.last_action_distribution, action_mask)
assert actions.dim() == 2 # TODO: remove this once we test everything
result["actions"] = actions.squeeze(dim=1)

@@ -121,10 +124,14 @@ def forward_head(self, normalized_obs_dict: Dict[str, Tensor]) -> Tensor:
def forward_core(self, head_output, rnn_states):
raise NotImplementedError()

def forward_tail(self, core_output, values_only: bool, sample_actions: bool) -> TensorDict:
def forward_tail(
self, core_output, values_only: bool, sample_actions: bool, action_mask: Optional[Tensor] = None
) -> TensorDict:
raise NotImplementedError()

def forward(self, normalized_obs_dict, rnn_states, values_only: bool = False) -> TensorDict:
def forward(
self, normalized_obs_dict, rnn_states, values_only: bool = False, action_mask: Optional[Tensor] = None
) -> TensorDict:
raise NotImplementedError()


@@ -160,7 +167,9 @@ def forward_core(self, head_output: Tensor, rnn_states):
x, new_rnn_states = self.core(head_output, rnn_states)
return x, new_rnn_states

def forward_tail(self, core_output, values_only: bool, sample_actions: bool) -> TensorDict:
def forward_tail(
self, core_output, values_only: bool, sample_actions: bool, action_mask: Optional[Tensor] = None
) -> TensorDict:
decoder_output = self.decoder(core_output)
values = self.critic_linear(decoder_output).squeeze()

@@ -173,13 +182,15 @@ def forward_tail(self, core_output, values_only: bool, sample_actions: bool) ->
# `action_logits` is not the best name here, better would be "action distribution parameters"
result["action_logits"] = action_distribution_params

self._maybe_sample_actions(sample_actions, result)
self._maybe_sample_actions(sample_actions, result, action_mask)
return result

def forward(self, normalized_obs_dict, rnn_states, values_only=False) -> TensorDict:
def forward(
self, normalized_obs_dict, rnn_states, values_only=False, action_mask: Optional[Tensor] = None
) -> TensorDict:
x = self.forward_head(normalized_obs_dict)
x, new_rnn_states = self.forward_core(x, rnn_states)
result = self.forward_tail(x, values_only, sample_actions=True)
result = self.forward_tail(x, values_only, sample_actions=True, action_mask=action_mask)
result["new_rnn_states"] = new_rnn_states
return result

@@ -276,7 +287,9 @@ def forward_head(self, normalized_obs_dict: Dict):
def forward_core(self, head_output, rnn_states):
return self.core_func(head_output, rnn_states)

def forward_tail(self, core_output, values_only: bool, sample_actions: bool) -> TensorDict:
def forward_tail(
self, core_output, values_only: bool, sample_actions: bool, action_mask: Optional[Tensor] = None
) -> TensorDict:
core_outputs = core_output.chunk(len(self.cores), dim=1)

# second core output corresponds to the critic
@@ -294,13 +307,15 @@ def forward_tail(self, core_output, values_only: bool, sample_actions: bool) ->

result["action_logits"] = action_distribution_params

self._maybe_sample_actions(sample_actions, result)
self._maybe_sample_actions(sample_actions, result, action_mask)
return result

def forward(self, normalized_obs_dict, rnn_states, values_only=False) -> TensorDict:
def forward(
self, normalized_obs_dict, rnn_states, values_only=False, action_mask: Optional[Tensor] = None
) -> TensorDict:
x = self.forward_head(normalized_obs_dict)
x, new_rnn_states = self.forward_core(x, rnn_states)
result = self.forward_tail(x, values_only, sample_actions=True)
result = self.forward_tail(x, values_only, sample_actions=True, action_mask=action_mask)
result["new_rnn_states"] = new_rnn_states
return result

@@ -321,4 +336,13 @@ def create_actor_critic(cfg: Config, obs_space: ObsSpace, action_space: ActionSp
from sample_factory.algo.utils.context import global_model_factory

make_actor_critic_func = global_model_factory().make_actor_critic_func
return make_actor_critic_func(cfg, obs_space, action_space)
return make_actor_critic_func(cfg, obs_space_without_action_mask(obs_space), action_space)


def obs_space_without_action_mask(obs_space: ObsSpace) -> ObsSpace:
if isinstance(obs_space, gym.spaces.Dict) and "action_mask" in obs_space.spaces:
spaces = obs_space.spaces.copy()
del spaces["action_mask"]
obs_space = gym.spaces.Dict(spaces)

return obs_space
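Stripping "action_mask" from the observation space means the encoders are sized only for the real observation entries while the mask still reaches forward(..., action_mask=...) at run time. A small illustration of the effect, with made-up shapes:

import numpy as np
import gymnasium as gym

obs_space = gym.spaces.Dict({
    "obs": gym.spaces.Box(low=0.0, high=1.0, shape=(3, 3, 2), dtype=np.float32),
    "action_mask": gym.spaces.Box(low=0, high=1, shape=(9,), dtype=np.int8),
})

# Equivalent of obs_space_without_action_mask: the model factory only ever
# sees the "obs" entry.
spaces = obs_space.spaces.copy()
spaces.pop("action_mask", None)
model_obs_space = gym.spaces.Dict(spaces)

assert "action_mask" not in model_obs_space.spaces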
5 changes: 4 additions & 1 deletion setup.py
@@ -27,6 +27,7 @@
"debugpy ~= 1.6",
]
_envpool_deps = ["envpool"]
_pettingzoo_deps = ["pettingzoo[classic]"]

_docs_deps = [
"mkdocs-material",
@@ -80,11 +81,13 @@ def is_macos():
"dev": ["black", "isort>=5.12", "pytest<8.0", "flake8", "pre-commit", "twine"]
+ _docs_deps
+ _atari_deps
+ _mujoco_deps,
+ _mujoco_deps
+ _pettingzoo_deps,
"atari": _atari_deps,
"envpool": _envpool_deps,
"mujoco": _mujoco_deps,
"nethack": _nethack_deps,
"pettingzoo": _pettingzoo_deps,
"vizdoom": ["vizdoom<2.0", "gymnasium[classic_control]"],
# "dmlab": ["dm_env"], <-- these are just auxiliary packages, the main package has to be built from sources
},
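With the new extra in place, the PettingZoo classic dependencies install together with Sample Factory; from a source checkout this would be, for example:

pip install -e .[pettingzoo]

The editable-install form is just one option; only the extra name itself comes from this diff.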
16 changes: 16 additions & 0 deletions sf_examples/enjoy_pettingzoo_env.py
@@ -0,0 +1,16 @@
import sys

from sample_factory.enjoy import enjoy
from sf_examples.train_pettingzoo_env import parse_custom_args, register_custom_components


def main(): # pragma: no cover
"""Script entry point."""
register_custom_components()
cfg = parse_custom_args(evaluation=True)
status = enjoy(cfg)
return status


if __name__ == "__main__": # pragma: no cover
sys.exit(main())
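The new example script reuses the parser and registrations from sf_examples.train_pettingzoo_env, so evaluation launches the same way as training. A hypothetical invocation; the env name and experiment value are placeholders, since train_pettingzoo_env.py and its argument defaults are not part of this excerpt:

python -m sf_examples.train_pettingzoo_env --env=tictactoe_v3 --experiment=tictactoe_example
python -m sf_examples.enjoy_pettingzoo_env --env=tictactoe_v3 --experiment=tictactoe_example

Both --env and --experiment are standard Sample Factory CLI flags.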
(Diffs for the remaining 5 changed files were not loaded in this view.)
