Commit
Support dataloader as input to audio for transcription (NVIDIA#9201)
* Support dataloader as input to `audio` for transcription

Signed-off-by: smajumdar <[email protected]>

* Apply isort and black reformatting

Signed-off-by: titu1994 <[email protected]>

* Support dataloader as input to `audio` for transcription

Signed-off-by: smajumdar <[email protected]>

* Update transcribe signatures

Signed-off-by: smajumdar <[email protected]>

* Apply isort and black reformatting

Signed-off-by: titu1994 <[email protected]>

---------

Signed-off-by: smajumdar <[email protected]>
Signed-off-by: titu1994 <[email protected]>
Signed-off-by: Boxiang Wang <[email protected]>
titu1994 authored and BoxiangW committed Jun 5, 2024
1 parent ec98601 commit 6352ddc
Showing 9 changed files with 139 additions and 44 deletions.
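What the change means in practice: `transcribe()` now accepts a prebuilt `torch.utils.data.DataLoader` (and, on most models, a raw numpy audio array) in addition to a path or list of paths. A minimal usage sketch; the checkpoint name and the `my_dataset` placeholder are illustrative assumptions, not part of this commit:

```python
import numpy as np
import nemo.collections.asr as nemo_asr
from torch.utils.data import DataLoader

# Illustrative checkpoint; any NeMo ASR model exposing transcribe() is used the same way.
model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained("stt_en_conformer_ctc_small")

# Existing behaviour: a single path or a list of paths to audio files.
hyps = model.transcribe(audio=["utt1.wav", "utt2.wav"], batch_size=2)

# Raw audio: a mono numpy array at the model's sample rate (16 kHz assumed here).
hyps = model.transcribe(audio=np.zeros(16000, dtype=np.float32))

# New in this commit: pass a DataLoader directly, and the model consumes its
# batches instead of building an internal dataloader. `my_dataset` is a
# hypothetical dataset whose batches already match what the model expects.
loader = DataLoader(my_dataset, batch_size=4, collate_fn=my_dataset.collate_fn)
hyps = model.transcribe(audio=loader)
```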
12 changes: 7 additions & 5 deletions nemo/collections/asr/models/aed_multitask_models.py
@@ -21,6 +21,7 @@
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning import Trainer
+ from torch.utils.data import DataLoader

from nemo.collections.asr.data.audio_to_text_lhotse_prompted import (
PromptedAudioToTextLhotseDataset,
@@ -156,7 +157,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.transf_encoder = EncDecMultiTaskModel.from_config_dict(transf_encoder_cfg_dict)

# Initialize weights
- std_init_range = 1 / self.cfg.model_defaults.lm_enc_hidden ** 0.5
+ std_init_range = 1 / self.cfg.model_defaults.lm_enc_hidden**0.5
self.transf_encoder.apply(lambda module: transformer_weights_init(module, std_init_range))

transf_decoder_cfg_dict = cfg.transf_decoder
@@ -182,7 +183,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight

# Initialize weights
- std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden ** 0.5
+ std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden**0.5
self.transf_decoder.apply(lambda module: transformer_weights_init(module, std_init_range))
self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range))

@@ -347,7 +348,7 @@ def change_vocabulary(
self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight

# Initialize weights of token classifier
- std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden ** 0.5
+ std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden**0.5
self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range))

# Setup Decoding class
@@ -387,7 +388,7 @@ def change_vocabulary(
@torch.no_grad()
def transcribe(
self,
- audio: Union[List[str], str],
+ audio: Union[str, List[str], np.ndarray, DataLoader],
batch_size: int = 4,
return_hypotheses: bool = False,
task: Optional[str] = None,
@@ -403,7 +404,8 @@ def transcribe(
"""
Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
Args:
- audio: (a list) of paths to audio files. \
+ audio: (a single or list) of paths to audio files or a np.ndarray audio array.
+ Can also be a dataloader object that provides values that can be consumed by the model.
Recommended length per file is between 5 and 25 seconds. \
But it is possible to pass a few hours long file if enough GPU memory is available.
batch_size: (int) batch size to use during inference.
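For the multitask (AED) model in particular, the widened `audio` argument composes with the prompt-style keyword arguments in the signature above. A hedged sketch; the checkpoint name and keyword names (`task`, `source_lang`, `target_lang`) are assumptions based on Canary-era NeMo and should be checked against the release you use:

```python
import nemo.collections.asr as nemo_asr

model = nemo_asr.models.EncDecMultiTaskModel.from_pretrained("nvidia/canary-1b")

hyps = model.transcribe(
    audio=["utt1.wav"],  # a path, list of paths, numpy array, or DataLoader
    batch_size=1,
    task="asr",          # assumed task name for plain transcription
    source_lang="en",
    target_lang="en",
)
```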
20 changes: 13 additions & 7 deletions nemo/collections/asr/models/classification_models.py
@@ -15,7 +15,6 @@
import copy
import json
import os
- import tempfile
from abc import abstractmethod
from dataclasses import dataclass, field
from math import ceil, floor
@@ -24,6 +23,7 @@
import torch
from omegaconf import DictConfig, ListConfig, OmegaConf
from pytorch_lightning import Trainer
+ from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError

@@ -169,7 +169,8 @@ def forward(

if not has_processed_signal:
processed_signal, processed_signal_length = self.preprocessor(
- input_signal=input_signal, length=input_signal_length,
+ input_signal=input_signal,
+ length=input_signal_length,
)
# Crop or pad is always applied
if self.crop_or_pad is not None:
@@ -355,7 +356,7 @@ def _setup_feature_label_dataloader(self, config: DictConfig) -> torch.utils.dat
@torch.no_grad()
def transcribe(
self,
- audio: List[str],
+ audio: Union[List[str], DataLoader],
batch_size: int = 4,
logprobs=None,
override_config: Optional[ClassificationInferConfig] | Optional[RegressionInferConfig] = None,
@@ -364,7 +365,8 @@
Generate class labels for provided audio files. Use this method for debugging and prototyping.
Args:
- audio: (a single or list) of paths to audio files or a np.ndarray audio sample. \
+ audio: (a single or list) of paths to audio files or a np.ndarray audio array.
+ Can also be a dataloader object that provides values that can be consumed by the model.
Recommended length per file is approximately 1 second.
batch_size: (int) batch size to use during inference. \
Bigger will result in better throughput performance but would use more memory.
@@ -952,7 +954,10 @@ def _setup_dataloader_from_config(self, config: DictConfig):

shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
dataset = audio_to_label_dataset.get_tarred_audio_multi_label_dataset(
- cfg=config, shuffle_n=shuffle_n, global_rank=self.global_rank, world_size=self.world_size,
+ cfg=config,
+ shuffle_n=shuffle_n,
+ global_rank=self.global_rank,
+ world_size=self.world_size,
)
shuffle = False
if hasattr(dataset, 'collate_fn'):
@@ -1022,7 +1027,8 @@ def forward(

if not has_processed_signal:
processed_signal, processed_signal_length = self.preprocessor(
- input_signal=input_signal, length=input_signal_length,
+ input_signal=input_signal,
+ length=input_signal_length,
)

# Crop or pad is always applied
@@ -1124,7 +1130,7 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
def reshape_labels(self, logits, labels, logits_len, labels_len):
"""
Reshape labels to match logits shape. For example, each label is expected to cover a 40ms frame, while each frame prediction from the
- model covers 20ms. If labels are shorter than logits, labels are repeated, otherwise labels are folded and argmax is applied to obtain
+ model covers 20ms. If labels are shorter than logits, labels are repeated, otherwise labels are folded and argmax is applied to obtain
the label of each frame. When lengths of labels and logits are not factors of each other, labels are truncated or padded with zeros.
The ratio_threshold=0.2 is used to determine whether to pad or truncate labels, where the value 0.2 is not important as in real cases the ratio
is very close to either ceil(ratio) or floor(ratio). We use 0.2 here for easier unit-testing. This implementation does not allow frame length
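An aside on the `reshape_labels` docstring above: the repeat/fold behaviour is easiest to see on toy tensors. The sketch below is a simplification, not the NeMo implementation; it ignores the ratio_threshold padding/truncation rule and uses a per-group mode in place of the one-hot-sum argmax:

```python
import torch

def toy_reshape_labels(labels: torch.Tensor, n_frames: int) -> torch.Tensor:
    """Stretch or shrink a 1-D label sequence to n_frames model frames."""
    n = labels.numel()
    if n < n_frames:
        # Labels are coarser than the logits (e.g. 40 ms labels vs 20 ms frames):
        # repeat each label, then trim to the frame count.
        ratio = -(-n_frames // n)  # ceil(n_frames / n)
        return labels.repeat_interleave(ratio)[:n_frames]
    # Labels are finer than the logits: fold groups of `ratio` labels into one
    # frame and take the majority label.
    ratio = n // n_frames
    folded = labels[: ratio * n_frames].view(n_frames, ratio)
    return folded.mode(dim=1).values

print(toy_reshape_labels(torch.tensor([0, 1, 1]), 6))           # tensor([0, 0, 1, 1, 1, 1])
print(toy_reshape_labels(torch.tensor([0, 0, 1, 1, 1, 1]), 3))  # tensor([0, 1, 1])
```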
29 changes: 22 additions & 7 deletions nemo/collections/asr/models/ctc_models.py
@@ -22,6 +22,7 @@
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning import Trainer
+ from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from nemo.collections.asr.data import audio_to_text_dataset
@@ -119,7 +120,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

def transcribe(
self,
- audio: Union[str, List[str], torch.Tensor, np.ndarray],
+ audio: Union[str, List[str], torch.Tensor, np.ndarray, DataLoader],
batch_size: int = 4,
return_hypotheses: bool = False,
num_workers: int = 0,
@@ -135,7 +136,8 @@
Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
Args:
- audio: (a single or list) of paths to audio files or a np.ndarray audio array. \
+ audio: (a single or list) of paths to audio files or a np.ndarray audio array.
+ Can also be a dataloader object that provides values that can be consumed by the model.
Recommended length per file is between 5 and 25 seconds. \
But it is possible to pass a few hours long file if enough GPU memory is available.
batch_size: (int) batch size to use during inference.
@@ -493,7 +495,8 @@ def forward(

if not has_processed_signal:
processed_signal, processed_signal_length = self.preprocessor(
- input_signal=input_signal, length=input_signal_length,
+ input_signal=input_signal,
+ length=input_signal_length,
)

if self.spec_augmentation is not None and self.training:
@@ -579,7 +582,9 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len)

transcribed_texts, _ = self.wer.decoding.ctc_decoder_predictions_tensor(
- decoder_outputs=log_probs, decoder_lengths=encoded_len, return_hypotheses=False,
+ decoder_outputs=log_probs,
+ decoder_lengths=encoded_len,
+ return_hypotheses=False,
)

sample_id = sample_id.cpu().detach().numpy()
@@ -601,11 +606,19 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0):
log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len
)
loss_value, metrics = self.add_interctc_losses(
- loss_value, transcript, transcript_len, compute_wer=True, log_wer_num_denom=True, log_prefix="val_",
+ loss_value,
+ transcript,
+ transcript_len,
+ compute_wer=True,
+ log_wer_num_denom=True,
+ log_prefix="val_",
)

self.wer.update(
- predictions=log_probs, targets=transcript, targets_lengths=transcript_len, predictions_lengths=encoded_len,
+ predictions=log_probs,
+ targets=transcript,
+ targets_lengths=transcript_len,
+ predictions_lengths=encoded_len,
)
wer, wer_num, wer_denom = self.wer.compute()
self.wer.reset()
@@ -677,7 +690,9 @@ def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> Gen
logits_len = outputs.pop('logits_len')

current_hypotheses, all_hyp = self.decoding.ctc_decoder_predictions_tensor(
- logits, decoder_lengths=logits_len, return_hypotheses=trcfg.return_hypotheses,
+ logits,
+ decoder_lengths=logits_len,
+ return_hypotheses=trcfg.return_hypotheses,
)
if trcfg.return_hypotheses:
if logits.is_cuda:
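Context for the `ctc_decoder_predictions_tensor` calls reformatted above: greedy CTC decoding is an argmax over time followed by collapsing repeats and dropping blanks. A self-contained toy sketch of that idea (not NeMo's implementation):

```python
import torch

def greedy_ctc_decode(log_probs: torch.Tensor, blank_id: int) -> list:
    """log_probs: (T, V) per-frame log-probabilities for a single utterance."""
    best = log_probs.argmax(dim=-1).tolist()  # best class id per frame
    out, prev = [], None
    for t in best:
        if t != prev and t != blank_id:  # collapse repeats, drop blanks
            out.append(t)
        prev = t
    return out

# Frames predicting classes [a, a, blank, a, b, b] with blank_id=2 decode to [a, a, b].
logits = torch.full((6, 3), -5.0)
for i, c in enumerate([0, 0, 2, 0, 1, 1]):
    logits[i, c] = 0.0
print(greedy_ctc_decode(logits, blank_id=2))  # [0, 0, 1]
```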
12 changes: 9 additions & 3 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -111,7 +111,8 @@ def transcribe(
Args:
- audio: (a list) of paths to audio files. \
+ audio: (a single or list) of paths to audio files or a np.ndarray audio array.
+ Can also be a dataloader object that provides values that can be consumed by the model.
Recommended length per file is between 5 and 25 seconds. \
But it is possible to pass a few hours long file if enough GPU memory is available.
batch_size: (int) batch size to use during inference. \
@@ -182,7 +183,9 @@ def _transcribe_output_processing(
encoded_len = outputs.pop('encoded_len')

best_hyp, all_hyp = self.ctc_decoding.ctc_decoder_predictions_tensor(
- logits, encoded_len, return_hypotheses=trcfg.return_hypotheses,
+ logits,
+ encoded_len,
+ return_hypotheses=trcfg.return_hypotheses,
)
logits = logits.cpu()

@@ -554,7 +557,10 @@ def validation_pass(self, batch, batch_idx, dataloader_idx):
loss_value = (1 - self.ctc_loss_weight) * loss_value + self.ctc_loss_weight * ctc_loss
tensorboard_logs['val_loss'] = loss_value
self.ctc_wer.update(
- predictions=log_probs, targets=transcript, targets_lengths=transcript_len, predictions_lengths=encoded_len,
+ predictions=log_probs,
+ targets=transcript,
+ targets_lengths=transcript_len,
+ predictions_lengths=encoded_len,
)
ctc_wer, ctc_wer_num, ctc_wer_denom = self.ctc_wer.compute()
self.ctc_wer.reset()
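In the validation hunk above, `ctc_loss_weight` linearly interpolates the two objectives: with w = `ctc_loss_weight`, loss = (1 - w) · loss_rnnt + w · loss_ctc, so w = 0 recovers pure RNNT training and w = 1 pure CTC.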
28 changes: 19 additions & 9 deletions nemo/collections/asr/models/rnnt_models.py
@@ -13,16 +13,15 @@
# limitations under the License.

import copy
- import json
import os
- import tempfile
from math import ceil
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning import Trainer
- from tqdm.auto import tqdm
+ from torch.utils.data import DataLoader

from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.data.audio_to_text import _AudioTextDataset
@@ -101,7 +100,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.cfg.decoding = self.set_decoding_type_according_to_loss(self.cfg.decoding)
# Setup decoding objects
self.decoding = RNNTDecoding(
- decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary,
+ decoding_cfg=self.cfg.decoding,
+ decoder=self.decoder,
+ joint=self.joint,
+ vocabulary=self.joint.vocabulary,
)
# Setup WER calculation
self.wer = WER(
@@ -236,7 +238,7 @@ def set_decoding_type_according_to_loss(self, decoding_cfg):
@torch.no_grad()
def transcribe(
self,
- audio: List[str],
+ audio: Union[str, List[str], np.ndarray, DataLoader],
batch_size: int = 4,
return_hypotheses: bool = False,
partial_hypothesis: Optional[List['Hypothesis']] = None,
@@ -250,7 +252,8 @@ def transcribe(
Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
Args:
- audio: (a list) of paths to audio files. \
+ audio: (a single or list) of paths to audio files or a np.ndarray audio array.
+ Can also be a dataloader object that provides values that can be consumed by the model.
Recommended length per file is between 5 and 25 seconds. \
But it is possible to pass a few hours long file if enough GPU memory is available.
batch_size: (int) batch size to use during inference. \
@@ -338,7 +341,10 @@ def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[Di
decoding_cfg = self.set_decoding_type_according_to_loss(decoding_cfg)

self.decoding = RNNTDecoding(
- decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary,
+ decoding_cfg=decoding_cfg,
+ decoder=self.decoder,
+ joint=self.joint,
+ vocabulary=self.joint.vocabulary,
)

self.wer = WER(
@@ -394,7 +400,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig):
decoding_cfg = self.set_decoding_type_according_to_loss(decoding_cfg)

self.decoding = RNNTDecoding(
- decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary,
+ decoding_cfg=decoding_cfg,
+ decoder=self.decoder,
+ joint=self.joint,
+ vocabulary=self.joint.vocabulary,
)

self.wer = WER(
@@ -649,7 +658,8 @@ def forward(

if not has_processed_signal:
processed_signal, processed_signal_length = self.preprocessor(
- input_signal=input_signal, length=input_signal_length,
+ input_signal=input_signal,
+ length=input_signal_length,
)

# Spec augment is not applied during evaluation/testing
15 changes: 8 additions & 7 deletions nemo/collections/asr/models/slu_models.py
@@ -13,15 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

- import json
import os
- import tempfile
from math import ceil
from typing import Any, Dict, List, Optional, Union

import torch
from omegaconf import DictConfig, OmegaConf, open_dict
- from tqdm.auto import tqdm
+ from torch.utils.data import DataLoader

from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs
@@ -190,7 +188,8 @@ def forward(

if not has_processed_signal:
processed_signal, processed_signal_length = self.preprocessor(
- input_signal=input_signal, length=input_signal_length,
+ input_signal=input_signal,
+ length=input_signal_length,
)

if self.spec_augmentation is not None and self.training:
@@ -278,7 +277,8 @@ def predict(

if not has_processed_signal:
processed_signal, processed_signal_length = self.preprocessor(
- input_signal=input_signal, length=input_signal_length,
+ input_signal=input_signal,
+ length=input_signal_length,
)

if self.spec_augmentation is not None and self.training:
@@ -560,7 +560,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
@torch.no_grad()
def transcribe(
self,
- audio: List[str],
+ audio: Union[List[str], DataLoader],
batch_size: int = 4,
return_hypotheses: bool = False,
num_workers: int = 0,
@@ -571,7 +571,8 @@
Use this method for debugging and prototyping.
Args:
- audio: (a list) of paths to audio files. \
+ audio: (a single or list) of paths to audio files or a np.ndarray audio array.
+ Can also be a dataloader object that provides values that can be consumed by the model.
Recommended length per file is between 5 and 25 seconds. \
But it is possible to pass a few hours long file if enough GPU memory is available.
batch_size: (int) batch size to use during inference.