fix to apply mask to probs
nkzawa committed Oct 5, 2024
1 parent 8ba3044 commit 1a66a74
Showing 6 changed files with 78 additions and 62 deletions.
2 changes: 1 addition & 1 deletion docs/07-advanced-topics/action-masking.md
@@ -1,6 +1,6 @@
# Action Masking

Action masking is a technique used to restrict the set of actions available to an agent in certain states. This can be particularly useful in environments where some actions are invalid or undesirable in specific situations. Sample Factory supports action masking, allowing you to implement this feature in your custom environments.
Action masking is a technique used to restrict the set of actions available to an agent in certain states. This can be particularly useful in environments where some actions are invalid or undesirable in specific situations. See the [paper](https://arxiv.org/abs/2006.14171) for more details.
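
For illustration only (this snippet is not part of the changed file): a masked environment typically exposes, next to its regular observation, a binary vector with ones at the currently valid actions. The observation-dict key used below (`action_mask`) and the dict-observation layout are assumptions for this sketch; the rest of this doc page describes the interface Sample Factory actually expects.

```python
import gymnasium as gym
import numpy as np


class MaskedToyEnv(gym.Env):
    """Toy env with 4 discrete actions; action 3 is only valid on even steps."""

    def __init__(self):
        self.action_space = gym.spaces.Discrete(4)
        self.observation_space = gym.spaces.Dict(
            {
                "obs": gym.spaces.Box(-1.0, 1.0, shape=(3,), dtype=np.float32),
                # assumed key name: 1.0 marks a valid action, 0.0 an invalid one
                "action_mask": gym.spaces.Box(0.0, 1.0, shape=(4,), dtype=np.float32),
            }
        )
        self.t = 0

    def _make_obs(self):
        mask = np.ones(4, dtype=np.float32)
        if self.t % 2 == 1:
            mask[3] = 0.0  # action 3 is invalid on odd steps
        return {"obs": np.zeros(3, dtype=np.float32), "action_mask": mask}

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self.t = 0
        return self._make_obs(), {}

    def step(self, action):
        self.t += 1
        return self._make_obs(), 0.0, False, False, {}
```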

## Implementing Action Masking

95 changes: 53 additions & 42 deletions sample_factory/algo/utils/action_distributions.py
@@ -42,7 +42,7 @@ def is_continuous_action_space(action_space: ActionSpace) -> bool:
return isinstance(action_space, gym.spaces.Box)


def get_action_distribution(action_space, raw_logits):
def get_action_distribution(action_space, raw_logits, action_mask=None):
"""
Create the distribution object based on provided action space and unprocessed logits.
:param action_space: Gym action space object
@@ -52,83 +52,98 @@ def get_action_distribution(action_space, raw_logits):
assert calc_num_action_parameters(action_space) == raw_logits.shape[-1]

if isinstance(action_space, gym.spaces.Discrete):
return CategoricalActionDistribution(raw_logits)
return CategoricalActionDistribution(raw_logits, action_mask)
elif isinstance(action_space, gym.spaces.Tuple):
return TupleActionDistribution(action_space, logits_flat=raw_logits)
return TupleActionDistribution(action_space, logits_flat=raw_logits, action_mask=action_mask)
elif isinstance(action_space, gym.spaces.Box):
return ContinuousActionDistribution(params=raw_logits)
else:
raise NotImplementedError(f"Action space type {type(action_space)} not supported!")


def sample_actions_log_probs(distribution, action_mask=None):
def sample_actions_log_probs(distribution):
if isinstance(distribution, TupleActionDistribution):
return distribution.sample_actions_log_probs(action_mask)
return distribution.sample_actions_log_probs()
else:
actions = distribution.sample(action_mask)
actions = distribution.sample()
log_prob_actions = distribution.log_prob(actions)
return actions, log_prob_actions


def argmax_actions(distribution, action_mask=None):
def argmax_actions(distribution):
if isinstance(distribution, TupleActionDistribution):
return distribution.argmax(action_mask)
return distribution.argmax()
elif hasattr(distribution, "probs"):
probs = distribution.probs
if action_mask is not None:
probs = probs * action_mask
return torch.argmax(probs, dim=-1)
return torch.argmax(distribution.probs, dim=-1)
elif hasattr(distribution, "means"):
return distribution.means
else:
raise NotImplementedError(f"Action distribution type {type(distribution)} does not support argmax!")


# Retrieved from AllenNLP:
# https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py#L243
def masked_softmax(logits, mask):
mask = mask.float()
probs = logits * mask + (1 - mask) * -1e9
probs = functional.softmax(probs, dim=-1)
probs = probs * mask
probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-9)
return probs
# To limit numerical errors from large vector elements outside the mask, we zero these out.
result = functional.softmax(logits * mask, dim=-1)
result = result * mask
result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
return result


# Retrieved from AllenNLP:
# https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py#L286
def masked_log_softmax(logits, mask):
# vector + mask.log() is an easy way to zero out masked elements in logspace, but it
# results in nans when the whole vector is masked. We need a very small value instead of a
# zero in the mask for these cases.
logits = logits + (mask + 1e-13).log()
return functional.log_softmax(logits, dim=-1)
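
As a quick illustration (not part of the commit): `masked_softmax` gives masked entries exactly zero probability while the surviving entries renormalize among themselves; `masked_log_softmax` pushes masked logits far into the negatives, and the `1e-13` added to the mask keeps a fully-masked row finite instead of NaN.

```python
import torch

from sample_factory.algo.utils.action_distributions import masked_log_softmax, masked_softmax

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
mask = torch.tensor([[1.0, 0.0, 1.0, 0.0]])  # actions 1 and 3 are invalid

probs = masked_softmax(logits, mask)
# roughly [[0.82, 0.00, 0.18, 0.00]] -- a softmax restricted to the valid entries
log_probs = masked_log_softmax(logits, mask)
# valid entries are close to probs.log(); masked entries get a large negative value, not -inf/nan
```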


# noinspection PyAbstractClass
class CategoricalActionDistribution:
def __init__(self, raw_logits):
def __init__(self, raw_logits, action_mask=None):
"""
Ctor.
:param raw_logits: unprocessed logits, typically an output of a fully-connected layer
"""

self.raw_logits = raw_logits
self.action_mask = action_mask
self.log_p = self.p = None

@property
def probs(self):
if self.p is None:
self.p = functional.softmax(self.raw_logits, dim=-1)
if self.action_mask is not None:
self.p = masked_softmax(self.raw_logits, self.action_mask)
else:
self.p = functional.softmax(self.raw_logits, dim=-1)
return self.p

@property
def log_probs(self):
if self.log_p is None:
self.log_p = functional.log_softmax(self.raw_logits, dim=-1)
if self.action_mask is not None:
self.log_p = masked_log_softmax(self.raw_logits, self.action_mask)
else:
self.log_p = functional.log_softmax(self.raw_logits, dim=-1)
return self.log_p

def sample_gumbel(self, action_mask=None):
def sample_gumbel(self):
probs = self.raw_logits - torch.empty_like(self.raw_logits).exponential_().log_()
if action_mask is not None:
probs = probs * action_mask
if self.action_mask is not None:
probs = probs * self.action_mask
sample = torch.argmax(probs, -1)
return sample

def sample(self, action_mask=None):
def sample(self):
probs = self.probs
if action_mask is not None:
probs = masked_softmax(self.raw_logits, action_mask)
if self.action_mask is not None:
all_zero = (probs.sum(dim=-1) == 0).unsqueeze(-1)
probs = torch.where(all_zero, self.probs, probs) # ensure sum of probabilities is non-zero
epsilons = torch.full_like(probs, 1e-6)
probs = torch.where(all_zero, epsilons, probs) # ensure sum of probabilities is non-zero

samples = torch.multinomial(probs, 1, True)
return samples
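
A sketch of the new flow (illustrative, not part of the commit): the mask is attached once, when the distribution is constructed, so `sample`, `argmax_actions`, and `sample_actions_log_probs` no longer take a mask argument. `gymnasium` is assumed here; older setups may import `gym` instead.

```python
import gymnasium as gym
import torch

from sample_factory.algo.utils.action_distributions import (
    argmax_actions,
    get_action_distribution,
    sample_actions_log_probs,
)

action_space = gym.spaces.Discrete(4)
raw_logits = torch.randn(2, 4)  # batch of 2 agents
action_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                            [0.0, 1.0, 1.0, 1.0]])

# the mask travels with the distribution object
dist = get_action_distribution(action_space, raw_logits, action_mask)

actions, log_prob_actions = sample_actions_log_probs(dist)  # only unmasked actions are drawn
greedy_actions = argmax_actions(dist)  # masked probs are zero, so argmax respects the mask too
```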
@@ -202,16 +217,18 @@ class TupleActionDistribution:
"""

def __init__(self, action_space, logits_flat):
def __init__(self, action_space, logits_flat, action_mask=None):
self.logit_lengths = [calc_num_action_parameters(s) for s in action_space.spaces]
self.split_logits = torch.split(logits_flat, self.logit_lengths, dim=1)
self.action_lengths = [calc_num_actions(s) for s in action_space.spaces]
self.action_mask = action_mask

assert len(self.split_logits) == len(action_space.spaces)

self.distributions = []
for i, space in enumerate(action_space.spaces):
self.distributions.append(get_action_distribution(space, self.split_logits[i]))
action_mask = self.action_mask[i] if self.action_mask is not None else None
self.distributions.append(get_action_distribution(space, self.split_logits[i], action_mask))

@staticmethod
def _flatten_actions(list_of_action_batches):
@@ -230,21 +247,18 @@ def _calc_log_probs(self, list_of_action_batches):

return log_probs

def sample_actions_log_probs(self, action_mask=None):
action_mask = [action_mask[i] if action_mask is not None else None for i in range(len(self.distributions))]
list_of_action_batches = [d.sample(action_mask[i]) for i, d in enumerate(self.distributions)]
def sample_actions_log_probs(self):
list_of_action_batches = [d.sample() for d in self.distributions]
batch_of_action_tuples = self._flatten_actions(list_of_action_batches)
log_probs = self._calc_log_probs(list_of_action_batches)
return batch_of_action_tuples, log_probs

def sample(self, action_mask=None):
action_mask = [action_mask[i] if action_mask is not None else None for i in range(len(self.distributions))]
list_of_action_batches = [d.sample(action_mask[i]) for i, d in enumerate(self.distributions)]
def sample(self):
list_of_action_batches = [d.sample() for d in self.distributions]
return self._flatten_actions(list_of_action_batches)

def argmax(self, action_mask=None):
action_mask = [action_mask[i] if action_mask is not None else None for i in range(len(self.distributions))]
list_of_action_batches = [argmax_actions(d, action_mask[i]) for i, d in enumerate(self.distributions)]
def argmax(self):
list_of_action_batches = [argmax_actions(d) for d in self.distributions]
return torch.cat(list_of_action_batches).unsqueeze(0)

def log_prob(self, actions):
@@ -297,9 +311,6 @@ def _init_impl(params: Tensor, stddev_min: float, stddev_max: float):
stddevs = torch.clamp(stddevs, stddev_min, stddev_max)
return means, log_std, stddevs

def sample(self, action_mask=None):
return super().sample()

def kl_divergence(self, other):
kl = torch.distributions.kl.kl_divergence(self, other)
return kl
2 changes: 1 addition & 1 deletion sample_factory/enjoy.py
@@ -157,7 +157,7 @@ def max_frames_reached(frames):

if cfg.eval_deterministic:
action_distribution = actor_critic.action_distribution()
actions = argmax_actions(action_distribution, action_mask)
actions = argmax_actions(action_distribution)

# actions shape should be [num_agents, num_actions] even if it's [1, 1]
if actions.ndim == 1:
12 changes: 8 additions & 4 deletions sample_factory/model/action_parameterization.py
@@ -30,10 +30,12 @@ def __init__(self, cfg, core_out_size, action_space):
num_action_outputs = calc_num_action_parameters(action_space)
self.distribution_linear = nn.Linear(core_out_size, num_action_outputs)

def forward(self, actor_core_output):
def forward(self, actor_core_output, action_mask=None):
"""Just forward the FC layer and generate the distribution object."""
action_distribution_params = self.distribution_linear(actor_core_output)
action_distribution = get_action_distribution(self.action_space, raw_logits=action_distribution_params)
action_distribution = get_action_distribution(
self.action_space, raw_logits=action_distribution_params, action_mask=action_mask
)
return action_distribution_params, action_distribution
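
Illustrative only: the parameterization now accepts the mask in `forward` and threads it into `get_action_distribution`. A minimal sketch, assuming the class is named `ActionParameterizationDefault` and that `cfg` is merely stored (not consulted) by it, so `None` is passed; both are assumptions, not guarantees.

```python
import gymnasium as gym
import torch

from sample_factory.model.action_parameterization import ActionParameterizationDefault

action_space = gym.spaces.Discrete(4)
core_out_size = 64
# cfg=None is an assumption for this sketch; in Sample Factory the real config object is passed
parameterization = ActionParameterizationDefault(None, core_out_size, action_space)

actor_core_output = torch.randn(2, core_out_size)
action_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                            [1.0, 0.0, 1.0, 1.0]])

params, dist = parameterization(actor_core_output, action_mask)
actions = dist.sample()  # sampling already respects the mask
```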


@@ -58,7 +60,7 @@ def __init__(self, cfg, core_out_size, action_space):
initial_stddev.fill_(math.log(self.cfg.initial_stddev))
self.learned_stddev = nn.Parameter(initial_stddev, requires_grad=True)

def forward(self, actor_core_output: Tensor):
def forward(self, actor_core_output: Tensor, action_mask=None):
action_means = self.distribution_linear(actor_core_output)
if self.tanh_scale > 0:
# scale the action means to be in the range [-tanh_scale, tanh_scale]
@@ -68,5 +70,7 @@ def forward(self, actor_core_output: Tensor):
batch_size = action_means.shape[0]
action_stddevs = self.learned_stddev.repeat(batch_size, 1)
action_distribution_params = torch.cat((action_means, action_stddevs), dim=1)
action_distribution = get_action_distribution(self.action_space, raw_logits=action_distribution_params)
action_distribution = get_action_distribution(
self.action_space, raw_logits=action_distribution_params, action_mask=action_mask
)
return action_distribution_params, action_distribution
18 changes: 10 additions & 8 deletions sample_factory/model/actor_critic.py
@@ -109,12 +109,10 @@ def summaries(self) -> Dict:
def action_distribution(self):
return self.last_action_distribution

def _maybe_sample_actions(
self, sample_actions: bool, result: TensorDict, action_mask: Optional[Tensor] = None
) -> None:
def _maybe_sample_actions(self, sample_actions: bool, result: TensorDict) -> None:
if sample_actions:
# for non-trivial action spaces it is faster to do these together
actions, result["log_prob_actions"] = sample_actions_log_probs(self.last_action_distribution, action_mask)
actions, result["log_prob_actions"] = sample_actions_log_probs(self.last_action_distribution)
assert actions.dim() == 2 # TODO: remove this once we test everything
result["actions"] = actions.squeeze(dim=1)

@@ -177,12 +175,14 @@ def forward_tail(
if values_only:
return result

action_distribution_params, self.last_action_distribution = self.action_parameterization(decoder_output)
action_distribution_params, self.last_action_distribution = self.action_parameterization(
decoder_output, action_mask
)

# `action_logits` is not the best name here, better would be "action distribution parameters"
result["action_logits"] = action_distribution_params

self._maybe_sample_actions(sample_actions, result, action_mask)
self._maybe_sample_actions(sample_actions, result)
return result

def forward(
@@ -303,11 +303,13 @@ def forward_tail(

# first core output corresponds to the actor
actor_decoder_output = self.actor_decoder(core_outputs[0])
action_distribution_params, self.last_action_distribution = self.action_parameterization(actor_decoder_output)
action_distribution_params, self.last_action_distribution = self.action_parameterization(
actor_decoder_output, action_mask
)

result["action_logits"] = action_distribution_params

self._maybe_sample_actions(sample_actions, result, action_mask)
self._maybe_sample_actions(sample_actions, result)
return result

def forward(
11 changes: 5 additions & 6 deletions tests/algo/test_action_distributions.py
@@ -25,11 +25,11 @@ def test_simple_distribution(self, gym_space, batch_size, has_action_mask):
simple_num_logits = calc_num_action_parameters(simple_action_space)
assert simple_num_logits == simple_action_space.n

expected_actions, action_mask = generate_expected_actions(simple_action_space.n, batch_size, has_action_mask)
simple_logits = torch.rand(batch_size, simple_num_logits)
simple_action_distribution = get_action_distribution(simple_action_space, simple_logits)
simple_action_distribution = get_action_distribution(simple_action_space, simple_logits, action_mask)

expected_actions, action_mask = generate_expected_actions(simple_action_space.n, batch_size, has_action_mask)
simple_actions = simple_action_distribution.sample(action_mask)
simple_actions = simple_action_distribution.sample()
assert list(simple_actions.shape) == [batch_size, 1]
assert all(torch.isin(a, expected_actions) for a in simple_actions)

@@ -103,12 +103,11 @@ def test_tuple_distribution(self, num_spaces, gym_space, batch_size, has_action_

assert num_logits == sum(s.n for s in action_space.spaces)

action_distribution = get_action_distribution(action_space, logits)

expected_actions, action_mask = generate_expected_actions(gym_space.n, batch_size, has_action_mask)
action_mask = action_mask.repeat(num_spaces, 1) if action_mask is not None else None
action_distribution = get_action_distribution(action_space, logits, action_mask)

tuple_actions = action_distribution.sample(action_mask)
tuple_actions = action_distribution.sample()
assert list(tuple_actions.shape) == [batch_size, num_spaces]
assert all(torch.isin(a, expected_actions) for actions in tuple_actions for a in actions)

