diff --git a/sample_factory/algo/utils/action_distributions.py b/sample_factory/algo/utils/action_distributions.py
index 5d0006c37..b279aeb3a 100644
--- a/sample_factory/algo/utils/action_distributions.py
+++ b/sample_factory/algo/utils/action_distributions.py
@@ -81,23 +81,17 @@ def argmax_actions(distribution):
     raise NotImplementedError(f"Action distribution type {type(distribution)} does not support argmax!")
 
 
-# Retrieved from AllenNLP:
-# https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py#L243
 def masked_softmax(logits, mask):
-    # To limit numerical errors from large vector elements outside the mask, we zero these out.
-    result = functional.softmax(logits * mask, dim=-1)
+    # Mask out the invalid logits by adding a large negative number (-1e9)
+    logits = logits + (mask == 0) * -1e9
+    result = functional.softmax(logits, dim=-1)
     result = result * mask
     result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
     return result
 
 
-# Retrieved from AllenNLP:
-# https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py#L286
 def masked_log_softmax(logits, mask):
-    # vector + mask.log() is an easy way to zero out masked elements in logspace, but it
-    # results in nans when the whole vector is masked. We need a very small value instead of a
-    # zero in the mask for these cases.
-    logits = logits + (mask + 1e-13).log()
+    logits = logits + (mask == 0) * -1e9
    return functional.log_softmax(logits, dim=-1)
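To make the motivation concrete, here is a small standalone repro (the tensors are illustrative, not from the repo). With the old multiplicative mask, `logits * mask` turns every masked entry into a logit of 0; when all the valid logits are strongly negative, the valid probabilities underflow to zero in the intermediate softmax, and the renormalization step then returns an all-zero "distribution". The additive `-1e9` mask keeps the valid entries dominant:

```python
import torch
import torch.nn.functional as functional


def masked_softmax(logits, mask):
    # Same logic as the patched function above, copied here so the snippet runs standalone.
    logits = logits + (mask == 0) * -1e9
    result = functional.softmax(logits, dim=-1)
    result = result * mask
    result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
    return result


# Illustrative inputs: both valid logits are strongly negative, the masked slot is not.
logits = torch.tensor([[-1000.0, -999.0, 50.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])

print(masked_softmax(logits, mask))
# tensor([[0.2689, 0.7311, 0.0000]]) -- matches softmax over the valid logits alone:
print(functional.softmax(logits[mask.bool()], dim=-1))
# tensor([0.2689, 0.7311])

# Old behavior: `logits * mask` maps the masked slot to logit 0, which dominates
# exp(-1000) and exp(-999); the valid probabilities underflow to zero, and
# renormalizing against +1e-13 leaves an all-zero "distribution".
old = functional.softmax(logits * mask, dim=-1)
old = (old * mask) / ((old * mask).sum(dim=-1, keepdim=True) + 1e-13)
print(old)
# tensor([[0., 0., 0.]])
```

The same reasoning applies to `masked_log_softmax`: adding `-1e9` to masked logits before `log_softmax` suppresses them far more reliably than `(mask + 1e-13).log()`, which only shifts masked slots by roughly `log(1e-13) ≈ -30`.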