From 6352ddc48fcc7adfc4248fe7daf6703a23e30c86 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Fri, 17 May 2024 09:18:36 -0700 Subject: [PATCH] Support dataloader as input to `audio` for transcription (#9201) * Support dataloader as input to `audio` for transcription Signed-off-by: smajumdar * Apply isort and black reformatting Signed-off-by: titu1994 * Support dataloader as input to `audio` for transcription Signed-off-by: smajumdar * Update transcribe signatures Signed-off-by: smajumdar * Apply isort and black reformatting Signed-off-by: titu1994 --------- Signed-off-by: smajumdar Signed-off-by: titu1994 Signed-off-by: Boxiang Wang --- .../asr/models/aed_multitask_models.py | 12 +++-- .../asr/models/classification_models.py | 20 +++++--- nemo/collections/asr/models/ctc_models.py | 29 +++++++++--- .../asr/models/hybrid_rnnt_ctc_models.py | 12 +++-- nemo/collections/asr/models/rnnt_models.py | 28 +++++++---- nemo/collections/asr/models/slu_models.py | 15 +++--- .../asr/models/transformer_bpe_models.py | 12 +++-- .../asr/parts/mixins/transcription.py | 9 +++- .../asr/mixins/test_transcription.py | 46 +++++++++++++++++++ 9 files changed, 139 insertions(+), 44 deletions(-) diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index f9413a4dd7383..b11d744a7e6a7 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -21,6 +21,7 @@ import torch from omegaconf import DictConfig, OmegaConf, open_dict from pytorch_lightning import Trainer +from torch.utils.data import DataLoader from nemo.collections.asr.data.audio_to_text_lhotse_prompted import ( PromptedAudioToTextLhotseDataset, @@ -156,7 +157,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.transf_encoder = EncDecMultiTaskModel.from_config_dict(transf_encoder_cfg_dict) # Initialize weights - std_init_range = 1 / self.cfg.model_defaults.lm_enc_hidden ** 0.5 + std_init_range = 1 / self.cfg.model_defaults.lm_enc_hidden**0.5 self.transf_encoder.apply(lambda module: transformer_weights_init(module, std_init_range)) transf_decoder_cfg_dict = cfg.transf_decoder @@ -182,7 +183,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight # Initialize weights - std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden ** 0.5 + std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden**0.5 self.transf_decoder.apply(lambda module: transformer_weights_init(module, std_init_range)) self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range)) @@ -347,7 +348,7 @@ def change_vocabulary( self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight # Initialize weights of token classifier - std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden ** 0.5 + std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden**0.5 self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range)) # Setup Decoding class @@ -387,7 +388,7 @@ def change_vocabulary( @torch.no_grad() def transcribe( self, - audio: Union[List[str], str], + audio: Union[str, List[str], np.ndarray, DataLoader], batch_size: int = 4, return_hypotheses: bool = False, task: Optional[str] = None, @@ -403,7 +404,8 @@ def transcribe( """ Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. Args: - audio: (a list) of paths to audio files. \ + audio: (a single or list) of paths to audio files or a np.ndarray audio array. + Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. batch_size: (int) batch size to use during inference. diff --git a/nemo/collections/asr/models/classification_models.py b/nemo/collections/asr/models/classification_models.py index c1294de5bdc08..7b226f59e364b 100644 --- a/nemo/collections/asr/models/classification_models.py +++ b/nemo/collections/asr/models/classification_models.py @@ -15,7 +15,6 @@ import copy import json import os -import tempfile from abc import abstractmethod from dataclasses import dataclass, field from math import ceil, floor @@ -24,6 +23,7 @@ import torch from omegaconf import DictConfig, ListConfig, OmegaConf from pytorch_lightning import Trainer +from torch.utils.data import DataLoader from torchmetrics import Accuracy from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError @@ -169,7 +169,8 @@ def forward( if not has_processed_signal: processed_signal, processed_signal_length = self.preprocessor( - input_signal=input_signal, length=input_signal_length, + input_signal=input_signal, + length=input_signal_length, ) # Crop or pad is always applied if self.crop_or_pad is not None: @@ -355,7 +356,7 @@ def _setup_feature_label_dataloader(self, config: DictConfig) -> torch.utils.dat @torch.no_grad() def transcribe( self, - audio: List[str], + audio: Union[List[str], DataLoader], batch_size: int = 4, logprobs=None, override_config: Optional[ClassificationInferConfig] | Optional[RegressionInferConfig] = None, @@ -364,7 +365,8 @@ def transcribe( Generate class labels for provided audio files. Use this method for debugging and prototyping. Args: - audio: (a single or list) of paths to audio files or a np.ndarray audio sample. \ + audio: (a single or list) of paths to audio files or a np.ndarray audio array. + Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is approximately 1 second. batch_size: (int) batch size to use during inference. \ Bigger will result in better throughput performance but would use more memory. @@ -952,7 +954,10 @@ def _setup_dataloader_from_config(self, config: DictConfig): shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0 dataset = audio_to_label_dataset.get_tarred_audio_multi_label_dataset( - cfg=config, shuffle_n=shuffle_n, global_rank=self.global_rank, world_size=self.world_size, + cfg=config, + shuffle_n=shuffle_n, + global_rank=self.global_rank, + world_size=self.world_size, ) shuffle = False if hasattr(dataset, 'collate_fn'): @@ -1022,7 +1027,8 @@ def forward( if not has_processed_signal: processed_signal, processed_signal_length = self.preprocessor( - input_signal=input_signal, length=input_signal_length, + input_signal=input_signal, + length=input_signal_length, ) # Crop or pad is always applied @@ -1124,7 +1130,7 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): def reshape_labels(self, logits, labels, logits_len, labels_len): """ Reshape labels to match logits shape. For example, each label is expected to cover a 40ms frame, while each frme prediction from the - model covers 20ms. If labels are shorter than logits, labels are repeated, otherwise labels are folded and argmax is applied to obtain + model covers 20ms. If labels are shorter than logits, labels are repeated, otherwise labels are folded and argmax is applied to obtain the label of each frame. When lengths of labels and logits are not factors of each other, labels are truncated or padded with zeros. The ratio_threshold=0.2 is used to determine whether to pad or truncate labels, where the value 0.2 is not important as in real cases the ratio is very close to either ceil(ratio) or floor(ratio). We use 0.2 here for easier unit-testing. This implementation does not allow frame length diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index 4df02b1177cd9..177da81f85f28 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -22,6 +22,7 @@ import torch from omegaconf import DictConfig, OmegaConf, open_dict from pytorch_lightning import Trainer +from torch.utils.data import DataLoader from tqdm.auto import tqdm from nemo.collections.asr.data import audio_to_text_dataset @@ -119,7 +120,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): def transcribe( self, - audio: Union[str, List[str], torch.Tensor, np.ndarray], + audio: Union[str, List[str], torch.Tensor, np.ndarray, DataLoader], batch_size: int = 4, return_hypotheses: bool = False, num_workers: int = 0, @@ -135,7 +136,8 @@ def transcribe( Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. Args: - audio: (a single or list) of paths to audio files or a np.ndarray audio array. \ + audio: (a single or list) of paths to audio files or a np.ndarray audio array. + Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. batch_size: (int) batch size to use during inference. @@ -493,7 +495,8 @@ def forward( if not has_processed_signal: processed_signal, processed_signal_length = self.preprocessor( - input_signal=input_signal, length=input_signal_length, + input_signal=input_signal, + length=input_signal_length, ) if self.spec_augmentation is not None and self.training: @@ -579,7 +582,9 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len) transcribed_texts, _ = self.wer.decoding.ctc_decoder_predictions_tensor( - decoder_outputs=log_probs, decoder_lengths=encoded_len, return_hypotheses=False, + decoder_outputs=log_probs, + decoder_lengths=encoded_len, + return_hypotheses=False, ) sample_id = sample_id.cpu().detach().numpy() @@ -601,11 +606,19 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0): log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len ) loss_value, metrics = self.add_interctc_losses( - loss_value, transcript, transcript_len, compute_wer=True, log_wer_num_denom=True, log_prefix="val_", + loss_value, + transcript, + transcript_len, + compute_wer=True, + log_wer_num_denom=True, + log_prefix="val_", ) self.wer.update( - predictions=log_probs, targets=transcript, targets_lengths=transcript_len, predictions_lengths=encoded_len, + predictions=log_probs, + targets=transcript, + targets_lengths=transcript_len, + predictions_lengths=encoded_len, ) wer, wer_num, wer_denom = self.wer.compute() self.wer.reset() @@ -677,7 +690,9 @@ def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> Gen logits_len = outputs.pop('logits_len') current_hypotheses, all_hyp = self.decoding.ctc_decoder_predictions_tensor( - logits, decoder_lengths=logits_len, return_hypotheses=trcfg.return_hypotheses, + logits, + decoder_lengths=logits_len, + return_hypotheses=trcfg.return_hypotheses, ) if trcfg.return_hypotheses: if logits.is_cuda: diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index 3eaab9961ef8d..9a5c4188aebd2 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -111,7 +111,8 @@ def transcribe( Args: - audio: (a list) of paths to audio files. \ + audio: (a single or list) of paths to audio files or a np.ndarray audio array. + Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. batch_size: (int) batch size to use during inference. \ @@ -182,7 +183,9 @@ def _transcribe_output_processing( encoded_len = outputs.pop('encoded_len') best_hyp, all_hyp = self.ctc_decoding.ctc_decoder_predictions_tensor( - logits, encoded_len, return_hypotheses=trcfg.return_hypotheses, + logits, + encoded_len, + return_hypotheses=trcfg.return_hypotheses, ) logits = logits.cpu() @@ -554,7 +557,10 @@ def validation_pass(self, batch, batch_idx, dataloader_idx): loss_value = (1 - self.ctc_loss_weight) * loss_value + self.ctc_loss_weight * ctc_loss tensorboard_logs['val_loss'] = loss_value self.ctc_wer.update( - predictions=log_probs, targets=transcript, targets_lengths=transcript_len, predictions_lengths=encoded_len, + predictions=log_probs, + targets=transcript, + targets_lengths=transcript_len, + predictions_lengths=encoded_len, ) ctc_wer, ctc_wer_num, ctc_wer_denom = self.ctc_wer.compute() self.ctc_wer.reset() diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index 386f2a9151420..cb2505fbadbff 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -13,16 +13,15 @@ # limitations under the License. import copy -import json import os -import tempfile from math import ceil from typing import Any, Dict, List, Optional, Tuple, Union +import numpy as np import torch from omegaconf import DictConfig, OmegaConf, open_dict from pytorch_lightning import Trainer -from tqdm.auto import tqdm +from torch.utils.data import DataLoader from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text import _AudioTextDataset @@ -101,7 +100,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.cfg.decoding = self.set_decoding_type_according_to_loss(self.cfg.decoding) # Setup decoding objects self.decoding = RNNTDecoding( - decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary, + decoding_cfg=self.cfg.decoding, + decoder=self.decoder, + joint=self.joint, + vocabulary=self.joint.vocabulary, ) # Setup WER calculation self.wer = WER( @@ -236,7 +238,7 @@ def set_decoding_type_according_to_loss(self, decoding_cfg): @torch.no_grad() def transcribe( self, - audio: List[str], + audio: Union[str, List[str], np.ndarray, DataLoader], batch_size: int = 4, return_hypotheses: bool = False, partial_hypothesis: Optional[List['Hypothesis']] = None, @@ -250,7 +252,8 @@ def transcribe( Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. Args: - audio: (a list) of paths to audio files. \ + audio: (a single or list) of paths to audio files or a np.ndarray audio array. + Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. batch_size: (int) batch size to use during inference. \ @@ -338,7 +341,10 @@ def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[Di decoding_cfg = self.set_decoding_type_according_to_loss(decoding_cfg) self.decoding = RNNTDecoding( - decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary, + decoding_cfg=decoding_cfg, + decoder=self.decoder, + joint=self.joint, + vocabulary=self.joint.vocabulary, ) self.wer = WER( @@ -394,7 +400,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig): decoding_cfg = self.set_decoding_type_according_to_loss(decoding_cfg) self.decoding = RNNTDecoding( - decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary, + decoding_cfg=decoding_cfg, + decoder=self.decoder, + joint=self.joint, + vocabulary=self.joint.vocabulary, ) self.wer = WER( @@ -649,7 +658,8 @@ def forward( if not has_processed_signal: processed_signal, processed_signal_length = self.preprocessor( - input_signal=input_signal, length=input_signal_length, + input_signal=input_signal, + length=input_signal_length, ) # Spec augment is not applied during evaluation/testing diff --git a/nemo/collections/asr/models/slu_models.py b/nemo/collections/asr/models/slu_models.py index 1303bbfde7ea2..c599b7f4272a7 100644 --- a/nemo/collections/asr/models/slu_models.py +++ b/nemo/collections/asr/models/slu_models.py @@ -13,15 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import os -import tempfile from math import ceil from typing import Any, Dict, List, Optional, Union import torch from omegaconf import DictConfig, OmegaConf, open_dict -from tqdm.auto import tqdm +from torch.utils.data import DataLoader from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs @@ -190,7 +188,8 @@ def forward( if not has_processed_signal: processed_signal, processed_signal_length = self.preprocessor( - input_signal=input_signal, length=input_signal_length, + input_signal=input_signal, + length=input_signal_length, ) if self.spec_augmentation is not None and self.training: @@ -278,7 +277,8 @@ def predict( if not has_processed_signal: processed_signal, processed_signal_length = self.preprocessor( - input_signal=input_signal, length=input_signal_length, + input_signal=input_signal, + length=input_signal_length, ) if self.spec_augmentation is not None and self.training: @@ -560,7 +560,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo @torch.no_grad() def transcribe( self, - audio: List[str], + audio: Union[List[str], DataLoader], batch_size: int = 4, return_hypotheses: bool = False, num_workers: int = 0, @@ -571,7 +571,8 @@ def transcribe( Use this method for debugging and prototyping. Args: - audio: (a list) of paths to audio files. \ + audio: (a single or list) of paths to audio files or a np.ndarray audio array. + Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. batch_size: (int) batch size to use during inference. diff --git a/nemo/collections/asr/models/transformer_bpe_models.py b/nemo/collections/asr/models/transformer_bpe_models.py index 21a5f34b30380..e7e67f8fbb2f4 100644 --- a/nemo/collections/asr/models/transformer_bpe_models.py +++ b/nemo/collections/asr/models/transformer_bpe_models.py @@ -24,6 +24,7 @@ import torch.distributed as dist from omegaconf import DictConfig, OmegaConf, open_dict from pytorch_lightning import Trainer +from torch.utils.data import DataLoader from torchmetrics.text import SacreBLEUScore from tqdm.auto import tqdm @@ -141,7 +142,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): num_layers=self.cfg.head.num_layers, ) self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight - std_init_range = 1 / self.transf_decoder.hidden_size ** 0.5 + std_init_range = 1 / self.transf_decoder.hidden_size**0.5 self.transf_decoder.apply(lambda module: transformer_weights_init(module, std_init_range)) self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range)) @@ -174,7 +175,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): @torch.no_grad() def transcribe( self, - audio: List[str], + audio: Union[List[str], DataLoader], batch_size: int = 4, return_hypotheses: bool = False, num_workers: int = 0, @@ -185,7 +186,8 @@ def transcribe( """ Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. Args: - audio: (a list) of paths to audio files. \ + audio: (a list) of paths to audio files. + Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. batch_size: (int) batch size to use during inference. @@ -225,7 +227,9 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): config, global_rank=self.global_rank, world_size=self.world_size, - dataset=LhotseSpeechToTextBpeDataset(tokenizer=self.tokenizer,), + dataset=LhotseSpeechToTextBpeDataset( + tokenizer=self.tokenizer, + ), ) dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config( diff --git a/nemo/collections/asr/parts/mixins/transcription.py b/nemo/collections/asr/parts/mixins/transcription.py index c252d498dc081..df8d6bac50a97 100644 --- a/nemo/collections/asr/parts/mixins/transcription.py +++ b/nemo/collections/asr/parts/mixins/transcription.py @@ -186,7 +186,7 @@ class TranscriptionMixin(ABC): @torch.no_grad() def transcribe( self, - audio: Union[str, List[str], np.ndarray], + audio: Union[str, List[str], np.ndarray, DataLoader], batch_size: int = 4, return_hypotheses: bool = False, num_workers: int = 0, @@ -201,6 +201,7 @@ def transcribe( Args: audio: (a single or list) of paths to audio files or a np.ndarray audio array. + Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. But it is possible to pass a few hours long file if enough GPU memory is available. batch_size: (int) batch size to use during inference. @@ -368,7 +369,11 @@ def transcribe_generator(self, audio, override_config: Optional[TranscribeConfig with tempfile.TemporaryDirectory() as tmpdir: transcribe_cfg._internal.temp_dir = tmpdir - dataloader = self._transcribe_input_processing(audio, transcribe_cfg) + # Create a DataLoader if not already present + if not isinstance(audio, DataLoader): + dataloader = self._transcribe_input_processing(audio, transcribe_cfg) + else: + dataloader = audio if hasattr(transcribe_cfg, 'verbose'): verbose = transcribe_cfg.verbose diff --git a/tests/collections/asr/mixins/test_transcription.py b/tests/collections/asr/mixins/test_transcription.py index 794213c723977..1a6f38681d0cc 100644 --- a/tests/collections/asr/mixins/test_transcription.py +++ b/tests/collections/asr/mixins/test_transcription.py @@ -22,6 +22,7 @@ import torch from torch.utils.data import DataLoader, Dataset +from nemo.collections.asr.data.audio_to_text import _speech_collate_fn from nemo.collections.asr.models import ASRModel from nemo.collections.asr.parts.mixins import TranscribeConfig, TranscriptionMixin from nemo.collections.asr.parts.mixins.transcription import GenericTranscriptionType @@ -121,6 +122,27 @@ def _transcribe_on_end(self, trcfg: TranscribeConfig): self.flag_end = True +class DummyDataset(Dataset): + def __init__(self, audio_tensors: List[str], config: Dict = None): + self.audio_tensors = audio_tensors + self.config = config + + def __getitem__(self, index): + data = self.audio_tensors[index] + samples = torch.tensor(data) + # Calculate seq length + seq_len = torch.tensor(samples.shape[0], dtype=torch.long) + + # Dummy text tokens + text_tokens = torch.tensor([0], dtype=torch.long) + text_tokens_len = torch.tensor(1, dtype=torch.long) + + return (samples, seq_len, text_tokens, text_tokens_len) + + def __len__(self): + return len(self.audio_tensors) + + @pytest.fixture() def dummy_model(): return TranscribableDummy() @@ -326,3 +348,27 @@ def test_transcribe_multiple_tensor(self, test_data_dir): assert len(outputs) == 2 assert isinstance(outputs[0], str) assert isinstance(outputs[1], str) + + @pytest.mark.with_downloads() + @pytest.mark.unit + def test_transcribe_dataloader(self, test_data_dir): + model = ASRModel.from_pretrained("stt_en_conformer_ctc_small") + + # Load audio file + import soundfile as sf + + audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav") + audio, sr = sf.read(audio_file, dtype='float32') + + audio_file2 = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an152-mwhw-b.wav") + audio2, sr = sf.read(audio_file2, dtype='float32') + + dataset = DummyDataset([audio, audio2]) + collate_fn = lambda x: _speech_collate_fn(x, pad_id=0) + dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0, collate_fn=collate_fn) + + # DataLoader test + outputs = model.transcribe(dataloader, batch_size=1) + assert len(outputs) == 2 + assert isinstance(outputs[0], str) + assert isinstance(outputs[1], str)