41 changes: 19 additions & 22 deletions rllib/env/base_env.py
@@ -1,6 +1,6 @@
import logging
from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING,\
Union
Union, Set

import gym
import ray
@@ -198,14 +198,13 @@ def get_sub_environments(
return []

@PublicAPI
def get_agent_ids(self) -> Dict[EnvID, List[AgentID]]:
"""Return the agent ids for each sub-environment.
def get_agent_ids(self) -> Set[AgentID]:
"""Return the agent ids for the sub_environment.

Returns:
A dict mapping from env_id to a list of agent_ids.
All agent ids for the environment.
"""
logger.warning("get_agent_ids() has not been implemented")
return {}
return {_DUMMY_AGENT_ID}
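
For illustration, a minimal sketch of the API change (the subclass below is hypothetical, not an RLlib class):

```python
class TwoAgentEnv:
    def get_agent_ids(self):
        # Old API returned {env_id: [agent_ids, ...]}; the new API returns
        # a flat Set[AgentID] for the environment.
        return {"agent_0", "agent_1"}

assert TwoAgentEnv().get_agent_ids() == {"agent_0", "agent_1"}
```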

@PublicAPI
def try_render(self, env_id: Optional[EnvID] = None) -> None:
@@ -234,8 +233,8 @@ def get_unwrapped(self) -> List[EnvType]:

@PublicAPI
@property
def observation_space(self) -> gym.spaces.Dict:
"""Returns the observation space for each environment.
def observation_space(self) -> gym.Space:
"""Returns the observation space for each agent.

Note: samples from the observation space need to be preprocessed into a
`MultiEnvDict` before being used by a policy.
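
For context, a hedged sketch of the `MultiEnvDict` layout this note refers to (all values are illustrative):

```python
# MultiEnvDict maps env_id -> agent_id -> value.
multi_env_obs = {
    0: {"agent_0": 0.1, "agent_1": 0.2},  # sub-environment 0
    1: {"agent_0": 0.3, "agent_1": 0.4},  # sub-environment 1
}
```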
@@ -248,7 +247,7 @@ def observation_space(self) -> gym.spaces.Dict:
@PublicAPI
@property
def action_space(self) -> gym.Space:
"""Returns the action space for each environment.
"""Returns the action space for each agent.

Note: samples from the action space need to be preprocessed into a
`MultiEnvDict` before being passed to `send_actions`.
@@ -270,6 +269,7 @@ def action_space_sample(self, agent_id: list = None) -> MultiEnvDict:
Returns:
A random action for each environment.
"""
logger.warning("action_space_sample() has not been implemented")
del agent_id
return {}

@@ -286,6 +286,7 @@ def observation_space_sample(self, agent_id: list = None) -> MultiEnvDict:
A random observation for each environment.
"""
logger.warning("observation_space_sample() has not been implemented")
del agent_id
return {}

@PublicAPI
@@ -326,8 +327,7 @@ def action_space_contains(self, x: MultiEnvDict) -> bool:
"""
return self._space_contains(self.action_space, x)

@staticmethod
def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
def _space_contains(self, space: gym.Space, x: MultiEnvDict) -> bool:
"""Check if the given space contains the observations of x.

Args:
@@ -337,17 +337,14 @@ def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
Returns:
True if the observations of x are contained in space.
"""
# this removes the agent_id key and inner dicts
# in MultiEnvDicts
flattened_obs = {
env_id: list(obs.values())
for env_id, obs in x.items()
}
ret = True
for env_id in flattened_obs:
for obs in flattened_obs[env_id]:
ret = ret and space[env_id].contains(obs)
return ret
agents = set(self.get_agent_ids())
for multi_agent_dict in x.values():
for agent_id, obs in multi_agent_dict.items():
if (agent_id not in agents) or (
not space[agent_id].contains(obs)):
return False

return True
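
A self-contained, gym-only sketch of the membership test the new body performs (agent ids and spaces are illustrative):

```python
import gym

# Per-agent spaces keyed by agent id.
space = gym.spaces.Dict({
    "agent_0": gym.spaces.Discrete(2),
    "agent_1": gym.spaces.Discrete(2),
})
agents = {"agent_0", "agent_1"}
x = {0: {"agent_0": 1, "agent_1": 0}}  # MultiEnvDict: env_id -> agent values

# Same logic as _space_contains: every agent in x must be known, and its
# value must lie in that agent's sub-space.
ok = all(
    agent_id in agents and space[agent_id].contains(value)
    for multi_agent_dict in x.values()
    for agent_id, value in multi_agent_dict.items()
)
assert ok
```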


# Fixed agent identifier when there is only the single agent in the env
216 changes: 199 additions & 17 deletions rllib/env/multi_agent_env.py
@@ -1,15 +1,19 @@
import gym
from typing import Callable, Dict, List, Tuple, Type, Optional, Union
import logging
from typing import Callable, Dict, List, Tuple, Type, Optional, Union, Set

from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.utils.annotations import ExperimentalAPI, override, PublicAPI
from ray.rllib.utils.annotations import ExperimentalAPI, override, PublicAPI, \
DeveloperAPI
from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiAgentDict, \
MultiEnvDict

# If the obs space is Dict type, look for the global state under this key.
ENV_STATE = "state"

logger = logging.getLogger(__name__)


@PublicAPI
class MultiAgentEnv(gym.Env):
@@ -20,6 +24,15 @@ class MultiAgentEnv(gym.Env):
referred to as "agents" or "RL agents".
"""

def __init__(self):
self.observation_space = None
self.action_space = None
self._agent_ids = {}

# do the action and observation spaces map from agent ids to spaces
# for the individual agents?
self._spaces_in_preferred_format = None

@PublicAPI
def reset(self) -> MultiAgentDict:
"""Resets the env and returns observations from ready agents.
@@ -81,20 +94,127 @@ def step(
"""
raise NotImplementedError

@ExperimentalAPI
def observation_space_contains(self, x: MultiAgentDict) -> bool:
"""Checks if the observation space contains the given key.

Args:
x: Observations to check.

Returns:
True if the observation space contains all observations in x.
"""
if not hasattr(self, "_spaces_in_preferred_format") or \
self._spaces_in_preferred_format is None:
self._spaces_in_preferred_format = \
self._check_if_space_maps_agent_id_to_sub_space()
if self._spaces_in_preferred_format:
return self.observation_space.contains(x)

logger.warning("observation_space_contains() has not been implemented")
return True

@ExperimentalAPI
def action_space_contains(self, x: MultiAgentDict) -> bool:
"""Checks if the action space contains the given action.

Args:
x: Actions to check.

Returns:
True if the action space contains all actions in x.
"""
if not hasattr(self, "_spaces_in_preferred_format") or \
self._spaces_in_preferred_format is None:
self._spaces_in_preferred_format = \
self._check_if_space_maps_agent_id_to_sub_space()
if self._spaces_in_preferred_format:
return self.action_space.contains(x)

logger.warning("action_space_contains() has not been implemented")
return True
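
In the preferred format, the check above reduces to a plain `gym.spaces.Dict` containment test; a minimal sketch (agent ids illustrative):

```python
import gym

action_space = gym.spaces.Dict({
    "agent_0": gym.spaces.Discrete(2),
    "agent_1": gym.spaces.Discrete(2),
})
# A MultiAgentDict of actions, one entry per agent.
assert action_space.contains({"agent_0": 0, "agent_1": 1})
```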

@ExperimentalAPI
def action_space_sample(self, agent_ids: list = None) -> MultiAgentDict:
"""Returns a random action for each environment, and potentially each
agent in that environment.

Args:
agent_ids: List of agent ids to sample actions for. If None or
empty list, sample actions for all agents in the
environment.

Returns:
A random action for each agent.
"""
if not hasattr(self, "_spaces_in_preferred_format") or \
self._spaces_in_preferred_format is None:
self._spaces_in_preferred_format = \
self._check_if_space_maps_agent_id_to_sub_space()
if self._spaces_in_preferred_format:
if agent_ids is None:
agent_ids = self.get_agent_ids()
samples = self.action_space.sample()
return {agent_id: samples[agent_id] for agent_id in agent_ids}
logger.warning("action_space_sample() has not been implemented")
del agent_ids
return {}
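
The sampling path above draws one joint sample from the `Dict` space and then filters it down to the requested agent ids; a hedged, gym-only sketch:

```python
import gym

action_space = gym.spaces.Dict({
    "agent_0": gym.spaces.Discrete(2),
    "agent_1": gym.spaces.Discrete(2),
})
samples = action_space.sample()  # e.g. {"agent_0": 1, "agent_1": 0}
# Keep only the requested agents, as action_space_sample() does.
subset = {aid: samples[aid] for aid in ["agent_0"]}
assert set(subset) == {"agent_0"}
```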

@ExperimentalAPI
def observation_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
"""Returns a random observation from the observation space for each
agent if agent_ids is None, otherwise returns a random observation for
the agents in agent_ids.

Args:
agent_ids: List of agent ids to sample observations for. If None or
empty list, sample observations for all agents in the
environment.

Returns:
A random observation for each agent.
"""

if not hasattr(self, "_spaces_in_preferred_format") or \
self._spaces_in_preferred_format is None:
self._spaces_in_preferred_format = \
self._check_if_space_maps_agent_id_to_sub_space()
Review comment (Member): This is purely a software engineering question: any reason we don't do this in __init__ rather than here, in a utility function?

Reply (Member Author): Users will inherit from this class, but they may not necessarily call this function inside their own __init__ methods. If I added this check to __init__, it could run before the user has defined their observation/action spaces.

Although calling it here is messy, it enables a nice external behavior: if a user defines their spaces in the preferred format, they do not need to implement any of these sampling or checking methods on their own.
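
To make the reply concrete, a hedged sketch of a user env in the preferred format (`MyEnv` is hypothetical and assumes this PR's `MultiAgentEnv`): both spaces are `gym.spaces.Dict` objects keyed by the agent ids, so the inherited sampling and containment methods work with no overrides.

```python
import gym
from ray.rllib.env.multi_agent_env import MultiAgentEnv

class MyEnv(MultiAgentEnv):
    def __init__(self):
        super().__init__()
        self._agent_ids = {"agent_0", "agent_1"}
        # Spaces in the preferred format: Dict keyed by agent id.
        self.observation_space = gym.spaces.Dict({
            "agent_0": gym.spaces.Discrete(3),
            "agent_1": gym.spaces.Discrete(3),
        })
        self.action_space = gym.spaces.Dict({
            "agent_0": gym.spaces.Discrete(2),
            "agent_1": gym.spaces.Discrete(2),
        })

env = MyEnv()
# No user-side overrides needed; the base class detects the format.
assert env.observation_space_contains(env.observation_space_sample())
assert env.action_space_contains(env.action_space_sample())
```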

if self._spaces_in_preferred_format:
if agent_ids is None:
agent_ids = self.get_agent_ids()
samples = self.observation_space.sample()
samples = {agent_id: samples[agent_id] for agent_id in agent_ids}
return samples
logger.warning("observation_space_sample() has not been implemented")
del agent_ids
return {}

@PublicAPI
def get_agent_ids(self) -> Set[AgentID]:
"""Returns a set of agent ids in the environment.

Returns:
Set of agent ids.
"""
if not isinstance(self._agent_ids, set):
self._agent_ids = set(self._agent_ids)
return self._agent_ids

@PublicAPI
def render(self, mode=None) -> None:
"""Tries to render the environment."""

# By default, do nothing.
pass

# yapf: disable
# __grouping_doc_begin__
@ExperimentalAPI
def with_agent_groups(
self,
groups: Dict[str, List[AgentID]],
obs_space: gym.Space = None,
act_space: gym.Space = None) -> "MultiAgentEnv":
"""Convenience method for grouping together agents in this env.

@@ -132,8 +252,9 @@ def with_agent_groups(
from ray.rllib.env.wrappers.group_agents_wrapper import \
GroupAgentsWrapper
return GroupAgentsWrapper(self, groups, obs_space, act_space)
# __grouping_doc_end__
# yapf: enable


@PublicAPI
def to_base_env(
@@ -182,6 +303,20 @@ def to_base_env(

return env

@DeveloperAPI
def _check_if_space_maps_agent_id_to_sub_space(self) -> bool:
# do the action and observation spaces map from agent ids to spaces
# for the individual agents?
obs_space_check = (
hasattr(self, "observation_space")
and isinstance(self.observation_space, gym.spaces.Dict)
and set(self.observation_space.keys()) == self.get_agent_ids())
action_space_check = (
hasattr(self, "action_space")
and isinstance(self.action_space, gym.spaces.Dict)
and set(self.action_space.keys()) == self.get_agent_ids())
return obs_space_check and action_space_check
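
The same check in isolation, without RLlib (`.spaces.keys()` is used here since it exists across gym versions; agent ids are illustrative):

```python
import gym

agent_ids = {"agent_0", "agent_1"}
obs_space = gym.spaces.Dict({
    "agent_0": gym.spaces.Box(-1.0, 1.0, (2,)),
    "agent_1": gym.spaces.Box(-1.0, 1.0, (2,)),
})
in_preferred_format = (
    isinstance(obs_space, gym.spaces.Dict)
    and set(obs_space.spaces.keys()) == agent_ids)
assert in_preferred_format
```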


def make_multi_agent(
env_name_or_creator: Union[str, Callable[[EnvContext], EnvType]],
@@ -242,6 +377,40 @@ def __init__(self, config=None):
self.dones = set()
self.observation_space = self.agents[0].observation_space
self.action_space = self.agents[0].action_space
self._agent_ids = set(range(num))

@override(MultiAgentEnv)
def observation_space_sample(self,
agent_ids: list = None) -> MultiAgentDict:
if agent_ids is None:
agent_ids = list(range(len(self.agents)))
obs = {
agent_id: self.observation_space.sample()
for agent_id in agent_ids
}

return obs

@override(MultiAgentEnv)
def action_space_sample(self,
agent_ids: list = None) -> MultiAgentDict:
if agent_ids is None:
agent_ids = list(range(len(self.agents)))
actions = {
agent_id: self.action_space.sample()
for agent_id in agent_ids
}

return actions

@override(MultiAgentEnv)
def action_space_contains(self, x: MultiAgentDict) -> bool:
return all(self.action_space.contains(val) for val in x.values())

@override(MultiAgentEnv)
def observation_space_contains(self, x: MultiAgentDict) -> bool:
return all(
self.observation_space.contains(val) for val in x.values())
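
Usage sketch for `make_multi_agent`, following the pattern in RLlib's docstring (the gym env id is illustrative):

```python
from ray.rllib.env.multi_agent_env import make_multi_agent

MultiAgentCartPole = make_multi_agent("CartPole-v0")
env = MultiAgentCartPole({"num_agents": 2})
obs = env.reset()            # {0: <obs>, 1: <obs>}
assert env.get_agent_ids() == {0, 1}
```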

@override(MultiAgentEnv)
def reset(self):
@@ -277,7 +446,7 @@ def __init__(self, make_env: Callable[[int], EnvType],

Args:
make_env (Callable[[int], EnvType]): Factory that produces a new
MultiAgentEnv intance. Must be defined, if the number of
MultiAgentEnv instance. Must be defined, if the number of
existing envs is less than num_envs.
existing_envs (List[MultiAgentEnv]): List of already existing
multi-agent envs.
@@ -355,18 +524,31 @@ def try_render(self, env_id: Optional[EnvID] = None) -> None:
@override(BaseEnv)
@PublicAPI
def observation_space(self) -> gym.spaces.Dict:
space = {
_id: env.observation_space
for _id, env in enumerate(self.envs)
}
return gym.spaces.Dict(space)
return self.envs[0].observation_space

@property
@override(BaseEnv)
@PublicAPI
def action_space(self) -> gym.Space:
space = {_id: env.action_space for _id, env in enumerate(self.envs)}
return gym.spaces.Dict(space)
return self.envs[0].action_space

@override(BaseEnv)
def observation_space_contains(self, x: MultiEnvDict) -> bool:
return all(
self.envs[0].observation_space_contains(val) for val in x.values())

@override(BaseEnv)
def action_space_contains(self, x: MultiEnvDict) -> bool:
return all(
self.envs[0].action_space_contains(val) for val in x.values())

@override(BaseEnv)
def observation_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
return self.envs[0].observation_space_sample(agent_ids)

@override(BaseEnv)
def action_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
return self.envs[0].action_space_sample(agent_ids)
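
All the wrapper methods above delegate to `envs[0]`, on the assumption that sub-environments are homogeneous; a dependency-free sketch of the pattern (stand-in classes):

```python
class SubEnv:  # stand-in for a single MultiAgentEnv
    def action_space_sample(self, agent_ids=None):
        return {aid: 0 for aid in (agent_ids or ["agent_0"])}

class Wrapper:  # stand-in for MultiAgentEnvWrapper
    def __init__(self, envs):
        self.envs = envs

    def action_space_sample(self, agent_ids=None):
        # All sub-envs share identical spaces, so envs[0] is representative.
        return self.envs[0].action_space_sample(agent_ids)

assert Wrapper([SubEnv()]).action_space_sample() == {"agent_0": 0}
```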


class _MultiAgentEnvState:
Expand Down
Loading