Skip to content

Commit

Permalink
Added Action masking for Space.sample() (#2906)
Browse files Browse the repository at this point in the history
* Allows a new RNG to be generated with seed=-1 and updated env_checker to fix bug if environment doesn't use np_random in reset

* Revert "fixed `gym.vector.make` where the checker was being applied in the opposite case than was intended to (#2871)"

This reverts commit 519dfd9.

* Remove bad pushed commits

* Fixed spelling in core.py

* Pins pytest to the last py 3.6 version

* Add support for action masking in Space.sample(mask=...)

* Fix action mask

* Fix action_mask

* Fix action_mask

* Added docstrings, fixed bugs and added taxi examples

* Fixed bugs

* Add tests for sample

* Add docstrings and test space sample mask Discrete and MultiBinary

* Add MultiDiscrete sampling and tests

* Remove sample mask from graph

* Update gym/spaces/multi_discrete.py

Co-authored-by: Markus Krimmel <[email protected]>

* Updates based on Marcus28 and jjshoots for Graph.py

* Updates based on Marcus28 and jjshoots for Graph.py

* jjshoot review

* jjshoot review

* Update assert check

* Update type hints

Co-authored-by: Markus Krimmel <[email protected]>
  • Loading branch information
pseudo-rnd-thoughts and Markus28 authored Jun 26, 2022
1 parent d750eb8 commit 024b0f5
Show file tree
Hide file tree
Showing 11 changed files with 564 additions and 73 deletions.
44 changes: 41 additions & 3 deletions gym/envs/toy_text/taxi.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,22 @@ class TaxiEnv(Env):
- 2: Y(ellow)
- 3: B(lue)
### Info
``step`` and ``reset(return_info=True)`` return an info dictionary that contains the keys "prob" and "action_mask":
the probability that the state is taken, and a mask of which actions will result in a change of state (to speed up training).
As Taxi's initial state is stochastic, the "prob" key returned by ``reset`` represents the probability of the
transition; however, this value is currently bugged and is always 1.0. This will be fixed soon.
As the steps are deterministic, the "prob" key returned by ``step`` represents the probability of the transition, which is always 1.0.
For some cases, taking an action will have no effect on the state of the agent.
In v0.25.0, ``info["action_mask"]`` contains an np.ndarray with one entry per action, specifying
whether that action will change the state.
To sample a modifying action, use ``action = env.action_space.sample(info["action_mask"])``
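Or with a Q-value based algorithm, restrict the argmax to the valid actions and map the result back to an action id: ``valid = np.where(info["action_mask"] == 1)[0]`` followed by ``action = valid[np.argmax(q_values[obs, valid])]``.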
### Rewards
- -1 per step unless other reward is triggered.
- +20 delivering passenger.
Expand All @@ -99,7 +115,7 @@ class TaxiEnv(Env):
```
### Version History
* v3: Map Correction + Cleaner Domain Description
* v3: Map Correction + Cleaner Domain Description, v0.25.0 action masking added to the reset and step information
* v2: Disallow Taxi start location = goal location, Update Taxi observations in the rollout, Update Taxi reward threshold.
* v1: Remove (3,2) from locs, add passidx<4 check
* v0: Initial versions release
Expand Down Expand Up @@ -214,14 +230,36 @@ def decode(self, i):
assert 0 <= i < 5
return reversed(out)

def action_mask(self, state: int) -> np.ndarray:
    """Computes an action mask for the action space using the state information.

    Args:
        state: Encoded state; ``self.decode(state)`` must yield
            ``(taxi_row, taxi_col, pass_loc, dest_idx)``.

    Returns:
        An ``np.int8`` array of shape ``(6,)`` where entry ``i`` is 1 if the
        checks below enable action ``i`` in ``state`` and 0 otherwise.
    """
    # presumably 0=south, 1=north, 2=east, 3=west, 4=pickup, 5=dropoff
    # -- confirm against the class docstring.
    mask = np.zeros(6, dtype=np.int8)
    taxi_row, taxi_col, pass_loc, _dest_idx = self.decode(state)
    taxi_loc = (taxi_row, taxi_col)

    # Vertical moves are limited only by the grid rows 0..4.
    if taxi_row < 4:
        mask[0] = 1
    if taxi_row > 0:
        mask[1] = 1

    # Horizontal moves additionally require a b":" separator in the map at
    # the column between the two cells (anything else blocks the move).
    if taxi_col < 4 and self.desc[taxi_row + 1, 2 * taxi_col + 2] == b":":
        mask[2] = 1
    if taxi_col > 0 and self.desc[taxi_row + 1, 2 * taxi_col] == b":":
        mask[3] = 1

    # pass_loc < 4 indexes self.locs; action 4 is enabled only when the taxi
    # stands on that location.
    if pass_loc < 4 and taxi_loc == self.locs[pass_loc]:
        mask[4] = 1

    # pass_loc == 4 enables action 5 on any location in self.locs. The former
    # extra comparison with self.locs[dest_idx] was redundant, because that
    # location is itself a member of self.locs.
    if pass_loc == 4 and taxi_loc in self.locs:
        mask[5] = 1
    return mask

def step(self, a):
    """Apply action ``a`` to the current state and return the step result.

    Args:
        a: Index into ``self.P[self.s]`` selecting the candidate transitions.

    Returns:
        A tuple ``(int(s), r, d, info)`` taken from the chosen transition
        ``(p, s, r, d)`` (presumably probability, next state, reward and
        done flag), where ``info`` holds ``"prob"`` (``p``) and
        ``"action_mask"`` (``self.action_mask(s)`` for the new state).
    """
    transitions = self.P[self.s][a]
    # The first element of every transition is passed to categorical_sample
    # together with the RNG to choose which transition is taken.
    i = categorical_sample([t[0] for t in transitions], self.np_random)
    p, s, r, d = transitions[i]
    self.s = s
    self.lastaction = a
    self.renderer.render_step()
    # A single return: the stale ``return (int(s), r, d, {"prob": p})`` that
    # preceded this line made the action mask unreachable.
    return int(s), r, d, {"prob": p, "action_mask": self.action_mask(s)}

def reset(
self,
Expand All @@ -239,7 +277,7 @@ def reset(
if not return_info:
return int(self.s)
else:
return int(self.s), {"prob": 1}
return int(self.s), {"prob": 1.0, "action_mask": self.action_mask(self.s)}

def render(self, mode="human"):
if self.render_mode is not None:
Expand Down
11 changes: 10 additions & 1 deletion gym/spaces/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np

import gym.error
from gym import logger
from gym.spaces.space import Space
from gym.utils import seeding
Expand Down Expand Up @@ -146,7 +147,7 @@ def is_bounded(self, manner: str = "both") -> bool:
else:
raise ValueError("manner is not in {'below', 'above', 'both'}")

def sample(self) -> np.ndarray:
def sample(self, mask: None = None) -> np.ndarray:
r"""Generates a single random sample inside the Box.
In creating a sample of the box, each coordinate is sampled (independently) from a distribution
Expand All @@ -157,9 +158,17 @@ def sample(self) -> np.ndarray:
* :math:`(-\infty, b]` : shifted negative exponential distribution
* :math:`(-\infty, \infty)` : normal distribution
Args:
mask: A mask for sampling values from the Box space, currently unsupported.
Returns:
A sampled value from the Box
"""
if mask is not None:
raise gym.error.Error(
f"Box.sample cannot be provided a mask, actual value: {mask}"
)

high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
sample = np.empty(self.shape)

Expand Down
17 changes: 16 additions & 1 deletion gym/spaces/dict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Implementation of a space that represents the cartesian product of other spaces as a dictionary."""
from collections import OrderedDict
from collections.abc import Mapping, Sequence
from typing import Any
from typing import Dict as TypingDict
from typing import Optional, Union

Expand Down Expand Up @@ -137,14 +138,28 @@ def seed(self, seed: Optional[Union[dict, int]] = None) -> list:

return seeds

def sample(self) -> dict:
def sample(self, mask: Optional[TypingDict[str, Any]] = None) -> dict:
    """Draw one independent random sample from every subspace.

    Args:
        mask: Optional dict with one mask per subspace; its keys must be equal
            to the keys of :attr:`self.spaces`. Each value is forwarded to the
            matching subspace's ``sample``.

    Returns:
        An ordered dictionary with the same keys as :attr:`self.spaces` and the
        sampled values of the subspaces.
    """
    # Without a mask every subspace is sampled with no argument at all.
    if mask is None:
        unmasked = OrderedDict()
        for key, subspace in self.spaces.items():
            unmasked[key] = subspace.sample()
        return unmasked

    assert isinstance(
        mask, dict
    ), f"Expects mask to be a dict, actual type: {type(mask)}"
    assert (
        mask.keys() == self.spaces.keys()
    ), f"Expect mask keys to be same as space keys, mask keys: {mask.keys()}, space keys: {self.spaces.keys()}"

    masked = OrderedDict()
    for key, subspace in self.spaces.items():
        masked[key] = subspace.sample(mask[key])
    return masked

def contains(self, x) -> bool:
Expand Down
30 changes: 28 additions & 2 deletions gym/spaces/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,40 @@ def __init__(
self.start = int(start)
super().__init__((), np.int64, seed)

def sample(self) -> int:
def sample(self, mask: Optional[np.ndarray] = None) -> int:
    """Generates a single random sample from this space.

    Without a mask the sample is drawn uniformly from the ``n`` values starting
    at ``start``; with a mask it is drawn uniformly from the entries marked valid.

    Args:
        mask: An optional mask for if an action can be selected.
            Expected `np.ndarray` of shape `(n,)` and dtype `np.int8` where `1` represents valid actions and `0` invalid / infeasible actions.
            If there are no possible actions (i.e. `np.all(mask == 0)`) then `space.start` will be returned.

    Returns:
        A sampled integer from the space
    """
    if mask is None:
        return int(self.start + self.np_random.integers(self.n))

    assert isinstance(
        mask, np.ndarray
    ), f"The expected type of the mask is np.ndarray, actual type: {type(mask)}"
    assert (
        mask.dtype == np.int8
    ), f"The expected dtype of the mask is np.int8, actual dtype: {mask.dtype}"
    assert mask.shape == (
        self.n,
    ), f"The expected shape of the mask is {(self.n,)}, actual shape: {mask.shape}"

    allowed = mask == 1
    only_binary = (mask == 0) | allowed
    assert np.all(
        only_binary
    ), f"All values of a mask should be 0 or 1, actual values: {mask}"

    # Nothing selectable: fall back to the first value of the space.
    if not np.any(allowed):
        return self.start

    # The mask is 1-D, so the flat indices are exactly the valid offsets.
    candidates = np.flatnonzero(allowed)
    return int(self.start + self.np_random.choice(candidates))

def contains(self, x) -> bool:
Expand Down
85 changes: 56 additions & 29 deletions gym/spaces/graph.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space."""
from collections import namedtuple
from typing import NamedTuple, Optional, Sequence, Union
from typing import NamedTuple, Optional, Sequence, Tuple, Union

import numpy as np

from gym.spaces.box import Box
from gym.spaces.discrete import Discrete
from gym.spaces.multi_discrete import MultiDiscrete
from gym.spaces.multi_discrete import SAMPLE_MASK_TYPE, MultiDiscrete
from gym.spaces.space import Space
from gym.utils import seeding

Expand Down Expand Up @@ -70,53 +70,80 @@ def __init__(

def _generate_sample_space(
self, base_space: Union[None, Box, Discrete], num: int
) -> Optional[Union[Box, Discrete]]:
# the possibility of this space , got {type(base_space)}aving nothing
if num == 0:
) -> Optional[Union[Box, MultiDiscrete]]:
if num == 0 or base_space is None:
return None

if isinstance(base_space, Box):
return Box(
low=np.array(max(1, num) * [base_space.low]),
high=np.array(max(1, num) * [base_space.high]),
shape=(num, *base_space.shape),
shape=(num,) + base_space.shape,
dtype=base_space.dtype,
seed=self._np_random,
seed=self.np_random,
)
elif isinstance(base_space, Discrete):
return MultiDiscrete(nvec=[base_space.n] * num, seed=self._np_random)
elif base_space is None:
return None
return MultiDiscrete(nvec=[base_space.n] * num, seed=self.np_random)
else:
raise AssertionError(
f"Only Box and Discrete can be accepted as a base_space, got {type(base_space)}, you should not have gotten this error."
f"Expects base space to be Box and Discrete, actual space: {type(base_space)}."
)

def _sample_sample_space(self, sample_space) -> Optional[np.ndarray]:
if sample_space is not None:
return sample_space.sample()
else:
return None

def sample(self) -> NamedTuple:
def sample(
self,
mask: Optional[
Tuple[
Optional[Union[np.ndarray, SAMPLE_MASK_TYPE]],
Optional[Union[np.ndarray, SAMPLE_MASK_TYPE]],
]
] = None,
num_nodes: int = 10,
num_edges: Optional[int] = None,
) -> NamedTuple:
"""Generates a single sample graph with num_nodes between 1 and 10 sampled from the Graph.
Args:
mask: An optional tuple of optional node and edge mask that is only possible with Discrete spaces
(Box spaces don't support sample masks).
If no `num_edges` is provided then the `edge_mask` is multiplied by the number of edges
num_nodes: The number of nodes that will be sampled, the default is 10 nodes
num_edges: An optional number of edges, otherwise, a random number between 0 and `num_nodes`^2
Returns:
A NamedTuple representing a graph with attributes .nodes, .edges, and .edge_links.
"""
num_nodes = self.np_random.integers(low=1, high=10)
assert (
num_nodes > 0
), f"The number of nodes is expected to be greater than 0, actual value: {num_nodes}"

# we only have edges when we have at least 2 nodes
num_edges = 0
if num_nodes > 1:
# maximal number of edges is (n*n) allowing self connections and two way is allowed
num_edges = self.np_random.integers(num_nodes * num_nodes)

node_sample_space = self._generate_sample_space(self.node_space, num_nodes)
edge_sample_space = self._generate_sample_space(self.edge_space, num_edges)
if mask is not None:
node_space_mask, edge_space_mask = mask
else:
node_space_mask, edge_space_mask = None, None

sampled_nodes = self._sample_sample_space(node_sample_space)
sampled_edges = self._sample_sample_space(edge_sample_space)
# we only have edges when we have at least 2 nodes
if num_edges is None:
if num_nodes > 1:
# maximal number of edges is `n*(n-1)` allowing self connections and two-way is allowed
num_edges = self.np_random.integers(num_nodes * (num_nodes - 1))
else:
num_edges = 0
if edge_space_mask is not None:
edge_space_mask = tuple(edge_space_mask for _ in range(num_edges))
else:
assert (
num_edges >= 0
), f"The number of edges is expected to be greater than 0, actual mask: {num_edges}"

sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
sampled_edge_space = self._generate_sample_space(self.edge_space, num_edges)

sampled_nodes = sampled_node_space.sample(node_space_mask)
sampled_edges = (
sampled_edge_space.sample(edge_space_mask)
if sampled_edge_space is not None
else None
)

sampled_edge_links = None
if sampled_edges is not None and num_edges > 0:
Expand Down
24 changes: 23 additions & 1 deletion gym/spaces/multi_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,36 @@ def shape(self) -> Tuple[int, ...]:
"""Has stricter type than gym.Space - never None."""
return self._shape # type: ignore

def sample(self) -> np.ndarray:
def sample(self, mask: Optional[np.ndarray] = None) -> np.ndarray:
    """Generates a single random sample from this space.

    Every binary variable is drawn independently as an integer in ``{0, 1}``;
    when a mask is given, the draw is multiplied element-wise by the mask.

    Args:
        mask: An optional np.ndarray to mask samples with expected shape of ``space.shape``.
            Where mask == 0 then the samples will be 0.

    Returns:
        Sampled values from space
    """
    if mask is None:
        return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)

    # Validate the mask before any random numbers are consumed.
    assert isinstance(
        mask, np.ndarray
    ), f"The expected type of the mask is np.ndarray, actual type: {type(mask)}"
    assert (
        mask.dtype == np.int8
    ), f"The expected dtype of the mask is np.int8, actual dtype: {mask.dtype}"
    assert (
        mask.shape == self.shape
    ), f"The expected shape of the mask is {self.shape}, actual shape: {mask.shape}"
    only_binary = (mask == 0) | (mask == 1)
    assert np.all(
        only_binary
    ), f"All values of a mask should be 0 or 1, actual values: {mask}"

    draws = self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
    return mask * draws

def contains(self, x) -> bool:
Expand Down
Loading

0 comments on commit 024b0f5

Please sign in to comment.