14 changes: 12 additions & 2 deletions nemo/collections/asr/models/aed_multitask_models.py
@@ -1028,6 +1028,12 @@ def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionCo
         enc_mask = outputs.pop('encoder_mask')
         decoder_input_ids = outputs.pop('decoder_input_ids')
         batch = outputs.pop('batch')
+        if isinstance(batch, PromptedAudioToTextMiniBatch):
+            batch_audio = batch.audio
+            batch_audio_lens = batch.audio_lens
+        else:
+            # Handling TensorDataset / external DataLoader
+            batch_audio, batch_audio_lens = batch[0], batch[1]

         del log_probs
         num_chunks = enc_states.shape[0]
@@ -1041,13 +1047,17 @@ def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionCo
             return_hypotheses=trcfg.return_hypotheses,
         )
         merge_to_be_done = trcfg.enable_chunking and len(hypotheses) > 1
+        if trcfg.enable_chunking:
+            assert isinstance(
+                batch, PromptedAudioToTextMiniBatch
+            ), "Chunking is only supported with Canary dataloaders"

         del enc_states, enc_mask, decoder_input_ids

         if trcfg.timestamps and self.timestamps_asr_model is not None:
             hypotheses = get_forced_aligned_timestamps_with_external_model(
-                audio=[audio.squeeze()[:audio_len] for audio, audio_len in zip(batch.audio, batch.audio_lens)],
-                batch_size=len(batch.audio),
+                audio=[audio.squeeze()[:audio_len] for audio, audio_len in zip(batch_audio, batch_audio_lens)],
+                batch_size=len(batch_audio),
                 external_ctc_model=self.timestamps_asr_model,
                 main_model_predictions=hypotheses,
                 timestamp_type='char' if merge_to_be_done else ['word', 'segment'],
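The new branch in this hunk only normalizes how audio and lengths are pulled out of the batch, so the later timestamp call works for both Canary mini-batches and plain tuple batches. Below is a minimal, runnable sketch of that unpacking (not part of the PR); the namedtuple here is a hypothetical stand-in for NeMo's PromptedAudioToTextMiniBatch, and only the .audio / .audio_lens attributes matter.

# Sketch of the batch-unpacking logic introduced above (stand-in class, not NeMo's).
from collections import namedtuple

import torch

PromptedAudioToTextMiniBatch = namedtuple("PromptedAudioToTextMiniBatch", ["audio", "audio_lens"])


def unpack_audio(batch):
    """Return (audio, audio_lens) for both Canary mini-batches and plain tuple batches."""
    if isinstance(batch, PromptedAudioToTextMiniBatch):
        return batch.audio, batch.audio_lens
    # TensorDataset / external DataLoader case: batch is an (audio, audio_lens, ...) tuple
    return batch[0], batch[1]


audio = torch.randn(2, 16000)
audio_lens = torch.tensor([16000, 12000])

a1, l1 = unpack_audio(PromptedAudioToTextMiniBatch(audio, audio_lens))
a2, l2 = unpack_audio((audio, audio_lens))
assert a1 is a2 is audio and l1 is l2 is audio_lens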
6 changes: 6 additions & 0 deletions nemo/collections/asr/parts/mixins/transcription.py
@@ -356,6 +356,12 @@ def transcribe_generator(self, audio, override_config: Optional[TranscribeConfig
             dataloader = self._transcribe_input_processing(audio, transcribe_cfg)
         else:
             dataloader = audio
+        assert isinstance(dataloader, DataLoader), "`dataloader` must be of type DataLoader at this point."
+
+        if (isinstance(audio, list) and isinstance(audio[0], (np.ndarray, torch.Tensor))) or isinstance(
+            audio, (np.ndarray, torch.Tensor)
+        ):
+            transcribe_cfg.enable_chunking = False  # Can't chunk tensors (don't know cuts)

         if hasattr(transcribe_cfg, 'verbose'):
             verbose = transcribe_cfg.verbose
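The mixin change switches chunking off when the caller passed raw numpy arrays or torch tensors (as the inline comment notes, there is no cut information to chunk on). A small self-contained sketch of that check follows (not part of the PR); the dataclass is a hypothetical stand-in for the real transcribe config, which carries the enable_chunking flag as trcfg.enable_chunking.

# Sketch of the "disable chunking for raw tensor input" check added above.
from dataclasses import dataclass

import numpy as np
import torch


@dataclass
class FakeTranscribeConfig:
    enable_chunking: bool = True


def maybe_disable_chunking(audio, cfg: FakeTranscribeConfig) -> FakeTranscribeConfig:
    tensor_types = (np.ndarray, torch.Tensor)
    is_tensor_input = isinstance(audio, tensor_types) or (
        isinstance(audio, list) and len(audio) > 0 and isinstance(audio[0], tensor_types)
    )
    if is_tensor_input:
        cfg.enable_chunking = False  # can't chunk tensors (no cuts available)
    return cfg


assert maybe_disable_chunking(np.zeros(16000, dtype=np.float32), FakeTranscribeConfig()).enable_chunking is False
assert maybe_disable_chunking(torch.zeros(16000), FakeTranscribeConfig()).enable_chunking is False
assert maybe_disable_chunking(["sample.wav"], FakeTranscribeConfig()).enable_chunking is True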