fix custom forward_torch_softmax (NVIDIA#6512) (NVIDIA#6517)
Signed-off-by: Abhinav Khattar <[email protected]>
Co-authored-by: Abhinav Khattar <[email protected]>
Signed-off-by: hsiehjackson <[email protected]>
2 people authored and hsiehjackson committed Jun 2, 2023
1 parent 56ce2a6 commit b460716
Showing 1 changed file with 4 additions and 3 deletions.
nemo/collections/nlp/modules/common/megatron/fused_softmax.py

@@ -51,9 +51,10 @@ def forward_torch_softmax(self, input, mask):
             input = input * self.scale
         mask_output = self.mask_func(input, mask) if mask is not None else input
         probs = torch.nn.Softmax(dim=-1)(mask_output)
-        all_k_masked = mask.all(axis=-1)
-        zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None]
-        probs = probs * zero_attention_mask
+        if mask is not None:
+            all_k_masked = mask.all(axis=-1)
+            zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None]
+            probs = probs * zero_attention_mask
 
         if self.input_in_float16 and self.softmax_in_fp32:
             if self.input_in_fp16:
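The bug: when forward_torch_softmax was called with mask=None, the old code still evaluated mask.all(axis=-1) and raised an AttributeError on the None mask; the patch guards the masked-row zeroing behind an explicit None check. Below is a minimal standalone sketch of the patched control flow, not the actual NeMo class: the class's configured self.mask_func and self.scale handling are replaced here with a hypothetical masked_fill and a plain scale argument for illustration.

import torch


def forward_torch_softmax(input, mask, scale=1.0):
    # Standalone sketch of the patched method body. `masked_fill` with a
    # large negative value stands in for the class's configured mask_func.
    input = input * scale
    mask_output = input.masked_fill(mask, -10000.0) if mask is not None else input
    probs = torch.nn.Softmax(dim=-1)(mask_output)
    if mask is not None:
        # Rows where every key is masked would otherwise come out of the
        # softmax as uniform probabilities; zero them out explicitly.
        all_k_masked = mask.all(axis=-1)
        zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None]
        probs = probs * zero_attention_mask
    return probs


x = torch.randn(2, 1, 4, 4)  # (batch, heads, queries, keys)
m = torch.zeros(2, 1, 4, 4, dtype=torch.bool)
print(forward_torch_softmax(x, m).shape)     # masked path: torch.Size([2, 1, 4, 4])
print(forward_torch_softmax(x, None).shape)  # before the fix, this raised AttributeError

Both call paths now run; previously only the first did.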
