Skip to content

Commit

Permalink
[ASR]:fixed augmentor argument for transcribe functionality of Hybrid…
Browse files Browse the repository at this point in the history
… CTC-RNNT model
  • Loading branch information
KunalDhawan committed Mar 24, 2023
1 parent 7ea8230 commit caee365
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def transcribe(
partial_hypothesis: Optional[List['Hypothesis']] = None,
num_workers: int = 0,
channel_selector: Optional[ChannelSelectorType] = None,
augmentor: DictConfig = None,
) -> (List[str], Optional[List['Hypothesis']]):
"""
Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
Expand All @@ -112,6 +113,7 @@ def transcribe(
With hypotheses can do some postprocessing like getting timestamp or rescoring
num_workers: (int) number of workers for DataLoader
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied.
Returns:
Returns a tuple of 2 items -
Expand All @@ -126,6 +128,7 @@ def transcribe(
partial_hypothesis=partial_hypothesis,
num_workers=num_workers,
channel_selector=channel_selector,
augmentor=augmentor,
)

if paths2audio_files is None or len(paths2audio_files) == 0:
Expand Down Expand Up @@ -172,6 +175,9 @@ def transcribe(
'channel_selector': channel_selector,
}

if augmentor:
config['augmentor'] = augmentor

temporary_datalayer = self._setup_transcribe_dataloader(config)
for test_batch in tqdm(temporary_datalayer, desc="Transcribing"):
encoded, encoded_len = self.forward(
Expand Down

0 comments on commit caee365

Please sign in to comment.