From 5041fcd57ccf6a967852d7860de92fee02538729 Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Fri, 28 Apr 2023 09:39:09 -0700 Subject: [PATCH] =?UTF-8?q?[BugFix]=20Force=20=5Fget=5Fbatch=5Fpreds()=20t?= =?UTF-8?q?o=20keep=20logits=20in=20decoder=20timestamp=E2=80=A6=20(#6500)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [BugFix] Force _get_batch_preds() to keep logits in decoder timestamps generator r1.18.0 Signed-off-by: Taejin Park * ignore keep_logits in FrameBatchASRLogits Signed-off-by: Taejin Park --------- Signed-off-by: Taejin Park --- .../asr/parts/utils/decoder_timestamps_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/nemo/collections/asr/parts/utils/decoder_timestamps_utils.py b/nemo/collections/asr/parts/utils/decoder_timestamps_utils.py index 8e81d49939cb..f26b0c6b701a 100644 --- a/nemo/collections/asr/parts/utils/decoder_timestamps_utils.py +++ b/nemo/collections/asr/parts/utils/decoder_timestamps_utils.py @@ -232,7 +232,7 @@ def get_wer_feat_logit(audio_file_path, asr, frame_len, tokens_per_chunk, delay, return hyp, tokens, log_prob -class FrameBatchASR_Logits(FrameBatchASR): +class FrameBatchASRLogits(FrameBatchASR): """ A class for streaming frame-based ASR. Inherits from FrameBatchASR and adds new capability of returning the logit output. @@ -260,10 +260,9 @@ def read_audio_file_and_return(self, audio_filepath: str, delay: float, model_st self.set_frame_reader(frame_reader) @torch.no_grad() - def _get_batch_preds(self): + def _get_batch_preds(self, keep_logits): device = self.asr_model.device for batch in iter(self.data_loader): - feat_signal, feat_signal_len = batch feat_signal, feat_signal_len = feat_signal.to(device), feat_signal_len.to(device) log_probs, encoded_len, predictions = self.asr_model( @@ -272,9 +271,12 @@ def _get_batch_preds(self): preds = torch.unbind(predictions) for pred in preds: self.all_preds.append(pred.cpu().numpy()) + # Always keep logits in FrameBatchASRLogits + _ = keep_logits log_probs_tup = torch.unbind(log_probs) for log_prob in log_probs_tup: self.all_logprobs.append(log_prob) + del log_probs, log_probs_tup del encoded_len del predictions @@ -635,7 +637,7 @@ def run_ASR_BPE_CTC(self, asr_model: Type[EncDecCTCModelBPE]) -> Tuple[Dict, Dic log_prediction=asr_model._cfg.get("log_prediction", False), ) - frame_asr = FrameBatchASR_Logits( + frame_asr = FrameBatchASRLogits( asr_model=asr_model, frame_len=self.chunk_len_in_sec, total_buffer=self.total_buffer_in_secs,