diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py
index 5e4f7ac510d1..f890fb1cb849 100644
--- a/nemo/collections/asr/models/aed_multitask_models.py
+++ b/nemo/collections/asr/models/aed_multitask_models.py
@@ -1028,6 +1028,12 @@ def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionCo
         enc_mask = outputs.pop('encoder_mask')
         decoder_input_ids = outputs.pop('decoder_input_ids')
         batch = outputs.pop('batch')
+        if isinstance(batch, PromptedAudioToTextMiniBatch):
+            batch_audio = batch.audio
+            batch_audio_lens = batch.audio_lens
+        else:
+            # Handling TensorDataset / external DataLoader
+            batch_audio, batch_audio_lens = batch[0], batch[1]
 
         del log_probs
         num_chunks = enc_states.shape[0]
@@ -1041,13 +1047,17 @@ def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionCo
             return_hypotheses=trcfg.return_hypotheses,
         )
         merge_to_be_done = trcfg.enable_chunking and len(hypotheses) > 1
+        if trcfg.enable_chunking:
+            assert isinstance(
+                batch, PromptedAudioToTextMiniBatch
+            ), "Chunking is only supported with Canary dataloaders"
 
         del enc_states, enc_mask, decoder_input_ids
 
         if trcfg.timestamps and self.timestamps_asr_model is not None:
             hypotheses = get_forced_aligned_timestamps_with_external_model(
-                audio=[audio.squeeze()[:audio_len] for audio, audio_len in zip(batch.audio, batch.audio_lens)],
-                batch_size=len(batch.audio),
+                audio=[audio.squeeze()[:audio_len] for audio, audio_len in zip(batch_audio, batch_audio_lens)],
+                batch_size=len(batch_audio),
                 external_ctc_model=self.timestamps_asr_model,
                 main_model_predictions=hypotheses,
                 timestamp_type='char' if merge_to_be_done else ['word', 'segment'],
diff --git a/nemo/collections/asr/parts/mixins/transcription.py b/nemo/collections/asr/parts/mixins/transcription.py
index d5a6add4d2c0..b1b02ac1d245 100644
--- a/nemo/collections/asr/parts/mixins/transcription.py
+++ b/nemo/collections/asr/parts/mixins/transcription.py
@@ -356,6 +356,12 @@ def transcribe_generator(self, audio, override_config: Optional[TranscribeConfig
                     dataloader = self._transcribe_input_processing(audio, transcribe_cfg)
                 else:
                     dataloader = audio
+                assert isinstance(dataloader, DataLoader), "`dataloader` must be of type DataLoader at this point."
+
+                if (isinstance(audio, list) and isinstance(audio[0], (np.ndarray, torch.Tensor))) or isinstance(
+                    audio, (np.ndarray, torch.Tensor)
+                ):
+                    transcribe_cfg.enable_chunking = False  # Can't chunk tensors (don't know cuts)
 
                 if hasattr(transcribe_cfg, 'verbose'):
                     verbose = transcribe_cfg.verbose
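
For context, a minimal usage sketch of the behavior these two changes combine to produce: raw waveform inputs (np.ndarray / torch.Tensor) now have `enable_chunking` forced off in `transcribe_generator`, while the Canary dataloader path (`PromptedAudioToTextMiniBatch`) keeps chunking and the external-model timestamp alignment. The model name `nvidia/canary-1b`, the 16 kHz sample rate, and the file name `sample.wav` below are illustrative assumptions, not part of this PR.

```python
# Minimal sketch, not part of this PR. Assumes a NeMo build that accepts raw
# waveform arrays in transcribe(); model name, sample rate, and file name are
# illustrative assumptions.
import numpy as np
from nemo.collections.asr.models import EncDecMultiTaskModel

asr_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-1b")  # assumed checkpoint

# Raw waveform input: transcribe_generator cannot know cut boundaries here,
# so with this change it sets transcribe_cfg.enable_chunking = False.
waveform = np.random.randn(16_000 * 5).astype(np.float32)  # 5 s of 16 kHz mono audio
hyps_from_tensor = asr_model.transcribe(waveform)

# File-path input goes through the Canary dataloader, which yields
# PromptedAudioToTextMiniBatch, so chunking and the forced-alignment
# timestamp path (batch.audio / batch.audio_lens) remain available.
hyps_from_files = asr_model.transcribe(["sample.wav"], timestamps=True)
```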