From 2a6f1569205630d06ea1f95eb997dc9ad91ed6af Mon Sep 17 00:00:00 2001
From: Daniel Galvez
Date: Wed, 5 Jun 2024 14:44:36 -0700
Subject: [PATCH] Always cast softmax inputs to float32 when in training mode.

While we don't need this for accurate results in bfloat16/float16, this is a
safety precaution to make sure that training accuracy does not regress.

Signed-off-by: Daniel Galvez
---
 .../collections/asr/parts/submodules/multi_head_attention.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py
index b687fccc9d4a..de86132a721b 100644
--- a/nemo/collections/asr/parts/submodules/multi_head_attention.py
+++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -669,7 +669,10 @@ def _compute_out_global_to_all(
         global_attn_scores = global_attn_scores.view(batch_size * self.h, max_num_global_attn_indices, seq_len)

         # compute global attn probs
-        global_attn_probs_float = nn.functional.softmax(global_attn_scores, dim=-1)
+        if self.training:
+            global_attn_probs_float = nn.functional.softmax(global_attn_scores, dim=-1, dtype=torch.float32)
+        else:
+            global_attn_probs_float = nn.functional.softmax(global_attn_scores, dim=-1)

         global_attn_probs = self.dropout(global_attn_probs_float)

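
For context: passing dtype=torch.float32 to nn.functional.softmax casts the input to
float32 before the exponentiation and row-sum and returns float32 probabilities, so the
training branch above feeds float32 activations into the dropout that follows, while the
eval branch keeps the incoming dtype. Below is a minimal standalone sketch of the
precision gap this guards against; the tensor shapes and values are made up for
illustration, and only public torch / torch.nn.functional APIs are used.

import torch
import torch.nn as nn

torch.manual_seed(0)

# Pretend these are half-precision attention scores, as produced under mixed-precision training.
scores = (torch.randn(2, 4, 8) * 10.0).half()

# Softmax computed with the float16 input; the probabilities come back in float16.
probs_fp16 = nn.functional.softmax(scores, dim=-1)

# Passing dtype=torch.float32 upcasts the input before the softmax and returns
# float32 probabilities, mirroring the training branch in the patch.
probs_fp32 = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)

# Any difference between the two is precision lost to float16 rounding; the patch
# avoids feeding that rounded result into training.
max_abs_diff = (probs_fp16.float() - probs_fp32).abs().max()
print(f"max |fp16 softmax - fp32 softmax| = {max_abs_diff.item():.3e}")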