Make it possible to control the tqdm progress bar in ASR models (NVIDIA#6375)
* Added verbose param to control tqdm progress bar

Signed-off-by: Timur Kasimov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Timur Kasimov <[email protected]>
Co-authored-by: Timur Kasimov <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: hsiehjackson <[email protected]>
3 people authored and hsiehjackson committed Jun 2, 2023
1 parent 2faaf54 commit e64fc69
Showing 7 changed files with 23 additions and 9 deletions.
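
At the call site, the new flag behaves like any other transcribe() keyword argument and defaults to the old behavior. A minimal usage sketch, not part of the commit; the checkpoint name and audio paths below are illustrative placeholders:

import nemo.collections.asr as nemo_asr

# Hypothetical example: the checkpoint name and file paths are placeholders.
asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name="stt_en_conformer_ctc_small")

# Default (verbose=True): the tqdm "Transcribing" bar is displayed, as before this commit.
texts = asr_model.transcribe(paths2audio_files=["audio1.wav", "audio2.wav"], batch_size=4)

# verbose=False suppresses the bar, e.g. in batch jobs or services that need clean logs.
quiet_texts = asr_model.transcribe(
    paths2audio_files=["audio1.wav", "audio2.wav"], batch_size=4, verbose=False,
)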
3 changes: 2 additions & 1 deletion nemo/collections/asr/models/asr_model.py
@@ -30,11 +30,12 @@

class ASRModel(ModelPT, ABC):
@abstractmethod
def transcribe(self, paths2audio_files: List[str], batch_size: int = 4) -> List[str]:
def transcribe(self, paths2audio_files: List[str], batch_size: int = 4, verbose: bool = True) -> List[str]:
"""
Takes paths to audio files and returns text transcription
Args:
paths2audio_files: paths to audio fragment to be transcribed
verbose: (bool) whether to display tqdm progress bar
Returns:
transcription texts
4 changes: 3 additions & 1 deletion nemo/collections/asr/models/ctc_models.py
@@ -119,6 +119,7 @@ def transcribe(
num_workers: int = 0,
channel_selector: Optional[ChannelSelectorType] = None,
augmentor: DictConfig = None,
verbose: bool = True,
) -> List[str]:
"""
If modify this function, please remember update transcribe_partial_audio() in
@@ -138,6 +139,7 @@ def transcribe(
num_workers: (int) number of workers for DataLoader
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`.
augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied.
verbose: (bool) whether to display tqdm progress bar
Returns:
A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files
"""
@@ -192,7 +194,7 @@ def transcribe(
config['augmentor'] = augmentor

temporary_datalayer = self._setup_transcribe_dataloader(config)
for test_batch in tqdm(temporary_datalayer, desc="Transcribing"):
for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=not verbose):
logits, logits_len, greedy_predictions = self.forward(
input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
)
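
The mechanism is identical in every file touched by this commit: the new argument only feeds tqdm's existing disable keyword, so the iteration itself is unchanged. A self-contained sketch of the pattern, with a dummy work list standing in for the transcription dataloader:

from time import sleep

from tqdm import tqdm

def run_batches(batches, verbose: bool = True):
    # disable=not verbose hides the bar entirely; the loop body runs either way.
    for batch in tqdm(batches, desc="Transcribing", disable=not verbose):
        sleep(0.01)  # stand-in for self.forward(...) on a batch

run_batches(range(50), verbose=True)   # shows a "Transcribing" progress bar
run_batches(range(50), verbose=False)  # runs silently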
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/hybrid_asr_tts_models.py
@@ -339,9 +339,9 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
"""Test epoch end hook for ASR model"""
return self.asr_model.multi_test_epoch_end(outputs=outputs, dataloader_idx=dataloader_idx)

def transcribe(self, paths2audio_files: List[str], batch_size: int = 4) -> List[str]:
def transcribe(self, paths2audio_files: List[str], batch_size: int = 4, verbose: bool = True) -> List[str]:
"""Transcribe audio data using ASR model"""
return self.asr_model.transcribe(paths2audio_files=paths2audio_files, batch_size=batch_size)
return self.asr_model.transcribe(paths2audio_files=paths2audio_files, batch_size=batch_size, verbose=verbose)

def setup_multiple_validation_data(self, val_data_config: Union[DictConfig, Dict]):
"""Setup multiple validation data for ASR model"""
4 changes: 3 additions & 1 deletion nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -101,6 +101,7 @@ def transcribe(
num_workers: int = 0,
channel_selector: Optional[ChannelSelectorType] = None,
augmentor: DictConfig = None,
verbose: bool = True,
) -> (List[str], Optional[List['Hypothesis']]):
"""
Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
@@ -117,6 +118,7 @@ def transcribe(
num_workers: (int) number of workers for DataLoader
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied.
verbose: (bool) whether to display tqdm progress bar
Returns:
Returns a tuple of 2 items -
@@ -182,7 +184,7 @@ def transcribe(
config['augmentor'] = augmentor

temporary_datalayer = self._setup_transcribe_dataloader(config)
for test_batch in tqdm(temporary_datalayer, desc="Transcribing"):
for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=not verbose):
encoded, encoded_len = self.forward(
input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
)
9 changes: 7 additions & 2 deletions nemo/collections/asr/models/k2_aligner_model.py
@@ -13,11 +13,13 @@
# limitations under the License.

import copy
import os
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from tqdm.auto import tqdm

from nemo.collections.asr.data.audio_to_ctm_dataset import FrameCtmUnit
from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs
@@ -536,7 +538,9 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0) -> List[Tuple[int, 'F
return self._predict_impl(encoded, encoded_len, transcript, transcript_len, sample_id)

@torch.no_grad()
def transcribe(self, manifest: List[str], batch_size: int = 4, num_workers: int = None,) -> List['FrameCtmUnit']:
def transcribe(
self, manifest: List[str], batch_size: int = 4, num_workers: int = None, verbose: bool = True,
) -> List['FrameCtmUnit']:
"""
Does alignment. Use this method for debugging and prototyping.
@@ -547,6 +551,7 @@ def transcribe(self, manifest: List[str], batch_size: int = 4, num_workers: int
batch_size: (int) batch size to use during inference. \
Bigger will result in better throughput performance but would use more memory.
num_workers: (int) number of workers for DataLoader
verbose: (bool) whether to display tqdm progress bar
Returns:
A list of four: (label, start_frame, length, probability), called FrameCtmUnit, \
@@ -582,7 +587,7 @@ def transcribe(self, manifest: List[str], batch_size: int = 4, num_workers: int
'num_workers': num_workers,
}
temporary_datalayer = self._model._setup_transcribe_dataloader(config)
for test_batch in tqdm(temporary_datalayer, desc="Aligning"):
for test_batch in tqdm(temporary_datalayer, desc="Aligning", disable=not verbose):
test_batch[0] = test_batch[0].to(device)
test_batch[1] = test_batch[1].to(device)
hypotheses += [unit for i, unit in self.predict_step(test_batch, 0)]
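
Unlike the other models, k2_aligner_model.py did not previously import a progress bar, so the diff also brings in tqdm via tqdm.auto, which renders the notebook widget under Jupyter and falls back to the plain console bar elsewhere. A small illustrative sketch of the same wiring; the loop body is a placeholder, not the aligner's real logic:

from tqdm.auto import tqdm  # notebook widget in Jupyter, console bar otherwise

def align(num_batches: int, verbose: bool = True):
    hypotheses = []
    for batch_idx in tqdm(range(num_batches), desc="Aligning", disable=not verbose):
        hypotheses.append(batch_idx)  # stand-in for self.predict_step(test_batch, 0)
    return hypotheses

align(8, verbose=False)  # no bar is drawn when verbose is False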
4 changes: 3 additions & 1 deletion nemo/collections/asr/models/rnnt_models.py
@@ -217,6 +217,7 @@ def transcribe(
num_workers: int = 0,
channel_selector: Optional[ChannelSelectorType] = None,
augmentor: DictConfig = None,
verbose: bool = True,
) -> Tuple[List[str], Optional[List['Hypothesis']]]:
"""
Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
@@ -233,6 +234,7 @@ def transcribe(
num_workers: (int) number of workers for DataLoader
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied.
verbose: (bool) whether to display tqdm progress bar
Returns:
Returns a tuple of 2 items -
* A list of greedy transcript texts / Hypothesis
@@ -283,7 +285,7 @@ def transcribe(
config['augmentor'] = augmentor

temporary_datalayer = self._setup_transcribe_dataloader(config)
for test_batch in tqdm(temporary_datalayer, desc="Transcribing"):
for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=True):
encoded, encoded_len = self.forward(
input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
)
4 changes: 3 additions & 1 deletion nemo/collections/asr/models/slu_models.py
@@ -544,6 +544,7 @@ def transcribe(
logprobs: bool = False,
return_hypotheses: bool = False,
num_workers: int = 0,
verbose: bool = True,
) -> List[str]:
"""
Uses greedy decoding to transcribe audio files into SLU semantics.
@@ -559,6 +560,7 @@ def transcribe(
return_hypotheses: (bool) Either return hypotheses or text
With hypotheses can do some postprocessing like getting timestamp or rescoring
num_workers: (int) number of workers for DataLoader
verbose: (bool) whether to display tqdm progress bar
Returns:
A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files
@@ -607,7 +609,7 @@ def transcribe(
}

temporary_datalayer = self._setup_transcribe_dataloader(config)
for test_batch in tqdm(temporary_datalayer, desc="Transcribing"):
for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=not verbose):
predictions = self.predict(
input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
)
