From e64fc6953962906393a9d7341df5fca86124fbf2 Mon Sep 17 00:00:00 2001 From: Timur <36272321+SN4KEBYTE@users.noreply.github.com> Date: Fri, 7 Apr 2023 00:30:26 +0700 Subject: [PATCH] Make possible to control tqdm progress bar in ASR models (#6375) * Added verbose param to control tqdm progress bar Signed-off-by: Timur Kasimov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Timur Kasimov Co-authored-by: Timur Kasimov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: hsiehjackson --- nemo/collections/asr/models/asr_model.py | 3 ++- nemo/collections/asr/models/ctc_models.py | 4 +++- nemo/collections/asr/models/hybrid_asr_tts_models.py | 4 ++-- nemo/collections/asr/models/hybrid_rnnt_ctc_models.py | 4 +++- nemo/collections/asr/models/k2_aligner_model.py | 9 +++++++-- nemo/collections/asr/models/rnnt_models.py | 4 +++- nemo/collections/asr/models/slu_models.py | 4 +++- 7 files changed, 23 insertions(+), 9 deletions(-) diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py index b72c0ac5bcd8..c0f4c1cd0a70 100644 --- a/nemo/collections/asr/models/asr_model.py +++ b/nemo/collections/asr/models/asr_model.py @@ -30,11 +30,12 @@ class ASRModel(ModelPT, ABC): @abstractmethod - def transcribe(self, paths2audio_files: List[str], batch_size: int = 4) -> List[str]: + def transcribe(self, paths2audio_files: List[str], batch_size: int = 4, verbose: bool = True) -> List[str]: """ Takes paths to audio files and returns text transcription Args: paths2audio_files: paths to audio fragment to be transcribed + verbose: (bool) whether to display tqdm progress bar Returns: transcription texts diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index 0b4f8a1db103..b7816ec5040d 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -119,6 +119,7 @@ def transcribe( num_workers: int = 0, channel_selector: Optional[ChannelSelectorType] = None, augmentor: DictConfig = None, + verbose: bool = True, ) -> List[str]: """ If modify this function, please remember update transcribe_partial_audio() in @@ -138,6 +139,7 @@ def transcribe( num_workers: (int) number of workers for DataLoader channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. + verbose: (bool) whether to display tqdm progress bar Returns: A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files """ @@ -192,7 +194,7 @@ def transcribe( config['augmentor'] = augmentor temporary_datalayer = self._setup_transcribe_dataloader(config) - for test_batch in tqdm(temporary_datalayer, desc="Transcribing"): + for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=not verbose): logits, logits_len, greedy_predictions = self.forward( input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device) ) diff --git a/nemo/collections/asr/models/hybrid_asr_tts_models.py b/nemo/collections/asr/models/hybrid_asr_tts_models.py index 65f4cd7c903a..23a98d13c404 100644 --- a/nemo/collections/asr/models/hybrid_asr_tts_models.py +++ b/nemo/collections/asr/models/hybrid_asr_tts_models.py @@ -339,9 +339,9 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): """Test epoch end hook for ASR model""" return self.asr_model.multi_test_epoch_end(outputs=outputs, dataloader_idx=dataloader_idx) - def transcribe(self, paths2audio_files: List[str], batch_size: int = 4) -> List[str]: + def transcribe(self, paths2audio_files: List[str], batch_size: int = 4, verbose: bool = True) -> List[str]: """Transcribe audio data using ASR model""" - return self.asr_model.transcribe(paths2audio_files=paths2audio_files, batch_size=batch_size) + return self.asr_model.transcribe(paths2audio_files=paths2audio_files, batch_size=batch_size, verbose=verbose) def setup_multiple_validation_data(self, val_data_config: Union[DictConfig, Dict]): """Setup multiple validation data for ASR model""" diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index 9b59d0eb2ef8..e3acec2c7420 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -101,6 +101,7 @@ def transcribe( num_workers: int = 0, channel_selector: Optional[ChannelSelectorType] = None, augmentor: DictConfig = None, + verbose: bool = True, ) -> (List[str], Optional[List['Hypothesis']]): """ Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. @@ -117,6 +118,7 @@ def transcribe( num_workers: (int) number of workers for DataLoader channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. + verbose: (bool) whether to display tqdm progress bar Returns: Returns a tuple of 2 items - @@ -182,7 +184,7 @@ def transcribe( config['augmentor'] = augmentor temporary_datalayer = self._setup_transcribe_dataloader(config) - for test_batch in tqdm(temporary_datalayer, desc="Transcribing"): + for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=not verbose): encoded, encoded_len = self.forward( input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device) ) diff --git a/nemo/collections/asr/models/k2_aligner_model.py b/nemo/collections/asr/models/k2_aligner_model.py index 402cf68ff234..54d342df40e4 100644 --- a/nemo/collections/asr/models/k2_aligner_model.py +++ b/nemo/collections/asr/models/k2_aligner_model.py @@ -13,11 +13,13 @@ # limitations under the License. import copy +import os from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch from omegaconf import DictConfig, OmegaConf, open_dict +from tqdm.auto import tqdm from nemo.collections.asr.data.audio_to_ctm_dataset import FrameCtmUnit from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs @@ -536,7 +538,9 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0) -> List[Tuple[int, 'F return self._predict_impl(encoded, encoded_len, transcript, transcript_len, sample_id) @torch.no_grad() - def transcribe(self, manifest: List[str], batch_size: int = 4, num_workers: int = None,) -> List['FrameCtmUnit']: + def transcribe( + self, manifest: List[str], batch_size: int = 4, num_workers: int = None, verbose: bool = True, + ) -> List['FrameCtmUnit']: """ Does alignment. Use this method for debugging and prototyping. @@ -547,6 +551,7 @@ def transcribe(self, manifest: List[str], batch_size: int = 4, num_workers: int batch_size: (int) batch size to use during inference. \ Bigger will result in better throughput performance but would use more memory. num_workers: (int) number of workers for DataLoader + verbose: (bool) whether to display tqdm progress bar Returns: A list of four: (label, start_frame, length, probability), called FrameCtmUnit, \ @@ -582,7 +587,7 @@ def transcribe(self, manifest: List[str], batch_size: int = 4, num_workers: int 'num_workers': num_workers, } temporary_datalayer = self._model._setup_transcribe_dataloader(config) - for test_batch in tqdm(temporary_datalayer, desc="Aligning"): + for test_batch in tqdm(temporary_datalayer, desc="Aligning", disable=not verbose): test_batch[0] = test_batch[0].to(device) test_batch[1] = test_batch[1].to(device) hypotheses += [unit for i, unit in self.predict_step(test_batch, 0)] diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index cd9664a3d312..a3e36dbc1522 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -217,6 +217,7 @@ def transcribe( num_workers: int = 0, channel_selector: Optional[ChannelSelectorType] = None, augmentor: DictConfig = None, + verbose: bool = True, ) -> Tuple[List[str], Optional[List['Hypothesis']]]: """ Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. @@ -233,6 +234,7 @@ def transcribe( num_workers: (int) number of workers for DataLoader channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. + verbose: (bool) whether to display tqdm progress bar Returns: Returns a tuple of 2 items - * A list of greedy transcript texts / Hypothesis @@ -283,7 +285,7 @@ def transcribe( config['augmentor'] = augmentor temporary_datalayer = self._setup_transcribe_dataloader(config) - for test_batch in tqdm(temporary_datalayer, desc="Transcribing"): + for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=True): encoded, encoded_len = self.forward( input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device) ) diff --git a/nemo/collections/asr/models/slu_models.py b/nemo/collections/asr/models/slu_models.py index 1aafe0dddf96..2062397c511c 100644 --- a/nemo/collections/asr/models/slu_models.py +++ b/nemo/collections/asr/models/slu_models.py @@ -544,6 +544,7 @@ def transcribe( logprobs: bool = False, return_hypotheses: bool = False, num_workers: int = 0, + verbose: bool = True, ) -> List[str]: """ Uses greedy decoding to transcribe audio files into SLU semantics. @@ -559,6 +560,7 @@ def transcribe( return_hypotheses: (bool) Either return hypotheses or text With hypotheses can do some postprocessing like getting timestamp or rescoring num_workers: (int) number of workers for DataLoader + verbose: (bool) whether to display tqdm progress bar Returns: A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files @@ -607,7 +609,7 @@ def transcribe( } temporary_datalayer = self._setup_transcribe_dataloader(config) - for test_batch in tqdm(temporary_datalayer, desc="Transcribing"): + for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=not verbose): predictions = self.predict( input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device) )