improve masked_softmax and masked_log_softmax for extreme cases
nkzawa committed Oct 6, 2024
1 parent 1a66a74 · commit b46ee0f
Showing 1 changed file with 4 additions and 10 deletions.
sample_factory/algo/utils/action_distributions.py (4 additions, 10 deletions)

@@ -81,23 +81,17 @@ def argmax_actions(distribution):
     raise NotImplementedError(f"Action distribution type {type(distribution)} does not support argmax!")
 
 
-# Retrieved from AllenNLP:
-# https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py#L243
 def masked_softmax(logits, mask):
-    # To limit numerical errors from large vector elements outside the mask, we zero these out.
-    result = functional.softmax(logits * mask, dim=-1)
+    # Mask out the invalid logits by adding a large negative number (-1e9)
+    logits = logits + (mask == 0) * -1e9
+    result = functional.softmax(logits, dim=-1)
     result = result * mask
     result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
     return result
 
 
-# Retrieved from AllenNLP:
-# https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py#L286
 def masked_log_softmax(logits, mask):
-    # vector + mask.log() is an easy way to zero out masked elements in logspace, but it
-    # results in nans when the whole vector is masked. We need a very small value instead of a
-    # zero in the mask for these cases.
-    logits = logits + (mask + 1e-13).log()
+    logits = logits + (mask == 0) * -1e9
     return functional.log_softmax(logits, dim=-1)
 
 
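To make the "extreme cases" in the commit title concrete, here is a small standalone sketch (not part of the commit). It contrasts the pre- and post-commit masked_softmax on logits whose valid entries are strongly negative; the _old/_new names are just labels for the two versions, and PyTorch is assumed:

import torch
from torch.nn import functional


def masked_softmax_old(logits, mask):
    # Pre-commit: zero out masked logits before the softmax. exp(0) = 1 for the
    # masked slots dominates the denominator when all valid logits are very negative.
    result = functional.softmax(logits * mask, dim=-1)
    result = result * mask
    result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
    return result


def masked_softmax_new(logits, mask):
    # Post-commit: push masked logits to -1e9 so they receive ~zero probability
    # regardless of the scale of the valid logits.
    logits = logits + (mask == 0) * -1e9
    result = functional.softmax(logits, dim=-1)
    result = result * mask
    result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
    return result


logits = torch.tensor([[-100.0, -101.0, 0.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])

print(masked_softmax_old(logits, mask))  # roughly [2e-31, 7e-32, 0, 0]: the row no longer sums to 1
print(masked_softmax_new(logits, mask))  # roughly [0.7311, 0.2689, 0, 0]: a valid distribution over unmasked actions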

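The same kind of input breaks the old masked_log_softmax: (mask + 1e-13).log() only shifts masked logits down by about 29.9, which is not enough once the valid logits sit far below zero. Another standalone sketch under the same assumptions (_old/_new are again labels, not names from the repository):

import torch
from torch.nn import functional


def masked_log_softmax_old(logits, mask):
    # Pre-commit: log(1e-13) ~ -29.9 is the entire masking offset.
    logits = logits + (mask + 1e-13).log()
    return functional.log_softmax(logits, dim=-1)


def masked_log_softmax_new(logits, mask):
    # Post-commit: -1e9 dominates any realistic logit.
    logits = logits + (mask == 0) * -1e9
    return functional.log_softmax(logits, dim=-1)


logits = torch.tensor([[-100.0, 0.0, 0.0]])
mask = torch.tensor([[1.0, 0.0, 0.0]])

# Old: the only valid action gets log-prob ~ -70.8, the masked ones ~ -0.69 each.
print(masked_log_softmax_old(logits, mask))
# New: the valid action gets log-prob ~ 0, as it should.
print(masked_log_softmax_new(logits, mask))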