Skip to content

Commit

Permalink
[BugFix] Force _get_batch_preds() to keep logits in decoder timestamps generator (#6500)
Browse files: browse the repository at this point in the history

* [BugFix] Force _get_batch_preds() to keep logits in decoder timestamps generator r1.18.0

Signed-off-by: Taejin Park <[email protected]>

* ignore keep_logits in FrameBatchASRLogits

Signed-off-by: Taejin Park <[email protected]>

---------

Signed-off-by: Taejin Park <[email protected]>
  • Loading branch information
tango4j authored Apr 28, 2023
1 parent bf4c8bb commit 5041fcd
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions nemo/collections/asr/parts/utils/decoder_timestamps_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def get_wer_feat_logit(audio_file_path, asr, frame_len, tokens_per_chunk, delay,
return hyp, tokens, log_prob


class FrameBatchASR_Logits(FrameBatchASR):
class FrameBatchASRLogits(FrameBatchASR):
"""
A class for streaming frame-based ASR.
Inherits from FrameBatchASR and adds new capability of returning the logit output.
Expand Down Expand Up @@ -260,10 +260,9 @@ def read_audio_file_and_return(self, audio_filepath: str, delay: float, model_st
self.set_frame_reader(frame_reader)

@torch.no_grad()
def _get_batch_preds(self):
def _get_batch_preds(self, keep_logits):
device = self.asr_model.device
for batch in iter(self.data_loader):

feat_signal, feat_signal_len = batch
feat_signal, feat_signal_len = feat_signal.to(device), feat_signal_len.to(device)
log_probs, encoded_len, predictions = self.asr_model(
Expand All @@ -272,9 +271,12 @@ def _get_batch_preds(self):
preds = torch.unbind(predictions)
for pred in preds:
self.all_preds.append(pred.cpu().numpy())
# Always keep logits in FrameBatchASRLogits
_ = keep_logits
log_probs_tup = torch.unbind(log_probs)
for log_prob in log_probs_tup:
self.all_logprobs.append(log_prob)
del log_probs, log_probs_tup
del encoded_len
del predictions

Expand Down Expand Up @@ -635,7 +637,7 @@ def run_ASR_BPE_CTC(self, asr_model: Type[EncDecCTCModelBPE]) -> Tuple[Dict, Dic
log_prediction=asr_model._cfg.get("log_prediction", False),
)

frame_asr = FrameBatchASR_Logits(
frame_asr = FrameBatchASRLogits(
asr_model=asr_model,
frame_len=self.chunk_len_in_sec,
total_buffer=self.total_buffer_in_secs,
Expand Down

0 comments on commit 5041fcd

Please sign in to comment.