diff --git a/nemo/collections/nlp/modules/common/megatron/fused_softmax.py b/nemo/collections/nlp/modules/common/megatron/fused_softmax.py
index 3dc0a00c55bd..2c914a67dd12 100644
--- a/nemo/collections/nlp/modules/common/megatron/fused_softmax.py
+++ b/nemo/collections/nlp/modules/common/megatron/fused_softmax.py
@@ -51,9 +51,10 @@ def forward_torch_softmax(self, input, mask):
             input = input * self.scale
         mask_output = self.mask_func(input, mask) if mask is not None else input
         probs = torch.nn.Softmax(dim=-1)(mask_output)
-        all_k_masked = mask.all(axis=-1)
-        zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None]
-        probs = probs * zero_attention_mask
+        if mask is not None:
+            all_k_masked = mask.all(axis=-1)
+            zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None]
+            probs = probs * zero_attention_mask
 
         if self.input_in_float16 and self.softmax_in_fp32:
             if self.input_in_fp32:
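
For context, here is a minimal standalone sketch of the guarded logic this patch introduces. It is not the NeMo class itself: `masked_softmax`, its `scale` default, and the `-10000.0` fill value are illustrative assumptions standing in for `forward_torch_softmax` and `self.mask_func`. It shows why the all-keys-masked handling must be skipped when `mask` is `None` (the unguarded code would fail on `mask.all(axis=-1)`).

```python
import torch

def masked_softmax(scores, mask=None, scale=1.0):
    # Mirrors the patched control flow: the all-keys-masked handling
    # only runs when a mask is actually provided.
    scores = scores * scale
    if mask is not None:
        # Stand-in for self.mask_func; fill value chosen for illustration.
        scores = scores.masked_fill(mask, -10000.0)
    probs = torch.nn.Softmax(dim=-1)(scores)
    if mask is not None:
        # Rows where every key is masked would otherwise softmax to a
        # uniform distribution; zero those attention rows out instead.
        all_k_masked = mask.all(axis=-1)
        zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None]
        probs = probs * zero_attention_mask
    return probs

# With mask=None the pre-patch code path would hit mask.all(axis=-1) and fail;
# the guarded version simply returns the plain softmax.
attn = torch.randn(2, 4, 8, 8)
print(masked_softmax(attn, mask=None).shape)      # torch.Size([2, 4, 8, 8])

mask = torch.zeros(2, 1, 8, 8, dtype=torch.bool)
mask[0, 0, 0, :] = True                           # query 0: every key masked
print(masked_softmax(attn, mask)[0, 0, 0].sum())  # tensor(0.)
```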