Offline and streaming inference support for hybrid model #6570

Merged (20 commits, May 11, 2023)

examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py
@@ -45,9 +45,10 @@

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from omegaconf import OmegaConf, open_dict

from nemo.collections.asr.metrics.wer import CTCDecodingConfig
from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel
from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR
from nemo.collections.asr.parts.utils.transcribe_utils import (
@@ -78,10 +79,19 @@
pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than standard one.
random_seed: Optional[int] = None # seed number going to be used in seed_everything()

# Set to True to output greedy timestamp information (supported models only)
compute_timestamps: bool = False

# Set to True to output language ID information
compute_langs: bool = False

# Chunked configs
chunk_len_in_secs: float = 1.6 # Chunk length in seconds
total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models",
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet and FastConformer models and 4 for Conformer models.

# Decoding strategy for CTC models
decoding: CTCDecodingConfig = CTCDecodingConfig()

# Decoding strategy for CTC models
decoding: CTCDecodingConfig = CTCDecodingConfig()
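The chunked settings above interact: roughly, the script converts chunk_len_in_secs and total_buffer_in_secs into encoder-frame counts using model_stride. A small sketch of that arithmetic, assuming a 10 ms preprocessor window stride (variable names here are illustrative, not the script's exact ones):

```python
import math

# Illustrative values matching the defaults above.
chunk_len_in_secs = 1.6
total_buffer_in_secs = 4.0
model_stride = 8               # encoder downsampling factor
feature_stride_in_secs = 0.01  # assumed 10 ms preprocessor window stride

# One encoder output frame covers model_stride * feature_stride seconds of audio.
model_stride_in_secs = model_stride * feature_stride_in_secs  # 0.08 s

# Encoder frames taken from the middle of each buffered window.
tokens_per_chunk = math.ceil(chunk_len_in_secs / model_stride_in_secs)  # 20

# Delay (in encoder frames) from the end of the buffer back to the chunk centre.
mid_delay = math.ceil(
    (chunk_len_in_secs + (total_buffer_in_secs - chunk_len_in_secs) / 2)
    / model_stride_in_secs
)  # 35
```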
@@ -108,6 +118,9 @@
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
torch.set_grad_enabled(False)

for key in cfg:
cfg[key] = None if cfg[key] == 'None' else cfg[key]

if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)

@@ -174,6 +187,23 @@
)
return cfg

# Setup decoding strategy
if hasattr(asr_model, 'change_decoding_strategy'):
if not isinstance(asr_model, EncDecCTCModel) and not isinstance(asr_model, EncDecHybridRNNTCTCModel):
raise ValueError("The script supports ctc model and hybrid model with ctc decodng!")

else:
if cfg.compute_langs:
raise ValueError("CTC models do not support `compute_langs` at the moment.")

if hasattr(
asr_model, 'cur_decoder'
): # hybrid model with ctc decoding, or other models that support switching decoders
asr_model.change_decoding_strategy(cfg.decoding, decoder_type='ctc')

else: # ctc model
asr_model.change_decoding_strategy(cfg.decoding)

asr_model.eval()
asr_model = asr_model.to(asr_model.device)
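For context, a minimal offline sketch of the decoder switch this script performs on hybrid checkpoints; the checkpoint name is only an example and exact API details may vary across NeMo versions:

```python
from omegaconf import OmegaConf

from nemo.collections.asr.metrics.wer import CTCDecodingConfig
from nemo.collections.asr.models import EncDecHybridRNNTCTCModel

# Example hybrid CTC/RNNT checkpoint name; substitute whatever hybrid model you use.
asr_model = EncDecHybridRNNTCTCModel.from_pretrained("stt_en_fastconformer_hybrid_large_pc")

# Hybrid models expose `cur_decoder`; select the CTC head, as the script above does.
asr_model.change_decoding_strategy(OmegaConf.structured(CTCDecodingConfig()), decoder_type="ctc")

# Plain offline transcription with the CTC branch active.
print(asr_model.transcribe(["/path/to/audio.wav"], batch_size=1)[0])
```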

examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py
@@ -68,6 +68,7 @@
import torch
from omegaconf import OmegaConf, open_dict

from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel
from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
from nemo.collections.asr.parts.utils.streaming_utils import (
BatchedFrameASRRNNT,
@@ -101,10 +102,16 @@ class TranscriptionConfig:
pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than standard one.
random_seed: Optional[int] = None # seed number going to be used in seed_everything()

# Set to True to output greedy timestamp information (supported models only)
compute_timestamps: bool = False

# Set to True to output language ID information
compute_langs: bool = False

# Chunked configs
chunk_len_in_secs: float = 1.6 # Chunk length in seconds
total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet and FastConformer models and 4 for Conformer models.

# Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
# device anyway, and do inference on CPU only if CUDA device is not found.
@@ -135,6 +142,9 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
torch.set_grad_enabled(False)

for key in cfg:
cfg[key] = None if cfg[key] == 'None' else cfg[key]

if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)

@@ -203,9 +213,23 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
decoding_cfg.strategy = "greedy_batch"
decoding_cfg.preserve_alignments = True # required to compute the middle token for transducers.
decoding_cfg.fused_batch_size = -1 # temporarily stop fused batch during inference.
decoding_cfg.beam.return_best_hypothesis = True
decoding_cfg.beam.return_best_hypothesis = True # return and write the best hypothesis only

# Setup decoding strategy
if hasattr(asr_model, 'change_decoding_strategy'):
if not isinstance(asr_model, EncDecRNNTModel) and not isinstance(asr_model, EncDecHybridRNNTCTCModel):
raise ValueError("The script supports rnnt model and hybrid model with rnnt decodng!")
else:
# rnnt model
if isinstance(asr_model, EncDecRNNTModel):
asr_model.change_decoding_strategy(decoding_cfg)

# hybrid ctc rnnt model with decoder_type = rnnt
if hasattr(asr_model, 'cur_decoder'):
asr_model.change_decoding_strategy(decoding_cfg, decoder_type='rnnt')

asr_model.change_decoding_strategy(decoding_cfg)
with open_dict(cfg):
cfg.decoding = decoding_cfg

with open_dict(cfg):
cfg.decoding = decoding_cfg
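The RNNT script above applies the same idea on the transducer side. A rough consolidated sketch of the decoding-config changes and the decoder switch, assuming asr_model is an already loaded EncDecRNNTModel or EncDecHybridRNNTCTCModel (illustrative, not the exact script):

```python
from copy import deepcopy

from omegaconf import open_dict

# Assume `asr_model` is an EncDecHybridRNNTCTCModel or EncDecRNNTModel already loaded.
decoding_cfg = deepcopy(asr_model.cfg.decoding)
with open_dict(decoding_cfg):
    decoding_cfg.strategy = "greedy_batch"
    decoding_cfg.preserve_alignments = True    # needed to pick the middle token per chunk
    decoding_cfg.fused_batch_size = -1         # disable fused batching during inference
    decoding_cfg.beam.return_best_hypothesis = True

if hasattr(asr_model, "cur_decoder"):          # hybrid model: pick the RNNT head
    asr_model.change_decoding_strategy(decoding_cfg, decoder_type="rnnt")
else:                                          # pure RNNT model
    asr_model.change_decoding_strategy(decoding_cfg)
```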
16 changes: 15 additions & 1 deletion examples/asr/transcribe_speech.py
@@ -227,6 +227,17 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
compute_timestamps = cfg.compute_timestamps
compute_langs = cfg.compute_langs

# Check whether model and decoder type match
if isinstance(asr_model, EncDecCTCModel):
if cfg.decoder_type and cfg.decoder_type != 'ctc':
raise ValueError('CTC models only support ctc decoding!')
elif isinstance(asr_model, EncDecHybridRNNTCTCModel):
if cfg.decoder_type and cfg.decoder_type not in ['ctc', 'rnnt']:
raise ValueError('Hybrid models only support ctc or rnnt decoding!')
else: # rnnt model; other model types may need to be handled here.
if cfg.decoder_type and cfg.decoder_type != 'rnnt':
raise ValueError('RNNT models only support rnnt decoding!')

# Setup decoding strategy
if hasattr(asr_model, 'change_decoding_strategy'):
if cfg.decoder_type is not None:
@@ -240,7 +251,10 @@
decoding_cfg.preserve_alignments = cfg.compute_timestamps
if 'compute_langs' in decoding_cfg:
decoding_cfg.compute_langs = cfg.compute_langs
asr_model.change_decoding_strategy(decoding_cfg, decoder_type=cfg.decoder_type)
if hasattr(asr_model, 'cur_decoder'):
asr_model.change_decoding_strategy(decoding_cfg, decoder_type=cfg.decoder_type)
else:
asr_model.change_decoding_strategy(decoding_cfg)

# Check if ctc or rnnt model
elif hasattr(asr_model, 'joint'): # RNNT model
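The new model/decoder compatibility check in transcribe_speech.py can be summarized as a small rule. A purely illustrative helper, not part of the PR:

```python
from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel

def allowed_decoder_types(asr_model) -> set:
    """Which values of cfg.decoder_type a model class accepts (None is always allowed)."""
    if isinstance(asr_model, EncDecCTCModel):
        return {"ctc"}
    if isinstance(asr_model, EncDecHybridRNNTCTCModel):
        return {"ctc", "rnnt"}
    return {"rnnt"}  # plain RNNT models; other model types would need their own rule

# Usage: raise early if the requested decoder does not exist on the model.
# if cfg.decoder_type and cfg.decoder_type not in allowed_decoder_types(asr_model):
#     raise ValueError(f"{type(asr_model).__name__} does not support decoder_type={cfg.decoder_type}")
```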
12 changes: 9 additions & 3 deletions nemo/collections/asr/parts/utils/streaming_utils.py
@@ -769,9 +769,15 @@ def _get_batch_preds(self, keep_logits=False):

feat_signal, feat_signal_len = batch
feat_signal, feat_signal_len = feat_signal.to(device), feat_signal_len.to(device)
log_probs, encoded_len, predictions = self.asr_model(
processed_signal=feat_signal, processed_signal_length=feat_signal_len
)
forward_outs = self.asr_model(processed_signal=feat_signal, processed_signal_length=feat_signal_len)

if len(forward_outs) == 2: # hybrid ctc rnnt model
encoded, encoded_len = forward_outs
log_probs = self.asr_model.ctc_decoder(encoder_output=encoded)
predictions = log_probs.argmax(dim=-1, keepdim=False)
else:
log_probs, encoded_len, predictions = forward_outs

preds = torch.unbind(predictions)
for pred in preds:
self.all_preds.append(pred.cpu().numpy())
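The streaming_utils change works because the two model families return different things from forward(). A minimal sketch of the dispatch with shapes noted in comments (illustrative, not the exact NeMo internals):

```python
def ctc_log_probs_from_forward(asr_model, feat_signal, feat_signal_len):
    """Return (log_probs, encoded_len, greedy predictions) for CTC or hybrid models."""
    outs = asr_model(processed_signal=feat_signal, processed_signal_length=feat_signal_len)
    if len(outs) == 2:
        # Hybrid CTC/RNNT: forward() returns only (encoded [B, D, T], encoded_len),
        # so the CTC head has to be applied explicitly.
        encoded, encoded_len = outs
        log_probs = asr_model.ctc_decoder(encoder_output=encoded)  # [B, T, vocab + 1]
        predictions = log_probs.argmax(dim=-1, keepdim=False)
    else:
        # Pure CTC models already return (log_probs, encoded_len, predictions).
        log_probs, encoded_len, predictions = outs
    return log_probs, encoded_len, predictions
```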
6 changes: 4 additions & 2 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
@@ -58,7 +58,8 @@ def get_buffered_pred_feat_rnnt(
print("Parsing manifest files...")
for l in mfst_f:
row = json.loads(l.strip())
filepaths.append(row['audio_filepath'])
audio_file = get_full_path(audio_file=row['audio_filepath'], manifest_file=manifest)
filepaths.append(audio_file)
if 'text' in row:
refs.append(row['text'])

@@ -149,8 +150,9 @@ def get_buffered_pred_feat(
row = json.loads(l.strip())
if 'text' in row:
refs.append(row['text'])
audio_file = get_full_path(audio_file=row['audio_filepath'], manifest_file=manifest)
# do not support partial audio
asr.read_audio_file(row['audio_filepath'], delay, model_stride_in_secs)
asr.read_audio_file(audio_file, delay, model_stride_in_secs)
hyp = asr.transcribe(tokens_per_chunk, delay)
hyps.append(hyp)

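Both hunks above replace raw manifest paths with get_full_path. A rough approximation of what that resolution does, assuming relative audio_filepath entries should be anchored at the manifest's directory (the real helper handles more cases):

```python
import json
import os

def resolve_audio_path(audio_filepath: str, manifest_file: str) -> str:
    """Approximation of get_full_path: keep usable paths, anchor relative ones at the manifest."""
    if os.path.isabs(audio_filepath) or os.path.exists(audio_filepath):
        return audio_filepath
    return os.path.join(os.path.dirname(os.path.abspath(manifest_file)), audio_filepath)

# Usage over a manifest, mirroring the loops above.
with open("manifest.json", "r", encoding="utf-8") as mfst_f:
    filepaths = [
        resolve_audio_path(json.loads(line)["audio_filepath"], "manifest.json")
        for line in mfst_f if line.strip()
    ]
```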
4 changes: 3 additions & 1 deletion tools/asr_evaluator/conf/eval.yaml
@@ -12,12 +12,14 @@ engine:
mode: offline # choose from offline, chunked or offline_by_chunked
chunk_len_in_secs: 1.6 #null # Need to specify if use buffered inference (default for offline_by_chunked is 20)
total_buffer_in_secs: 4 #null # Need to specify if use buffered inference (default for offline_by_chunked is 22)
model_stride: 4 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models
model_stride: 8 # Model downsampling factor, 8 for Citrinet and FastConformer models, and 4 for Conformer models
decoder_type: null # Used for hybrid CTC RNNT model only. Specify decoder_type *ctc* or *rnnt* for hybrid CTC RNNT model.

test_ds:
manifest_filepath: null
sample_rate: 16000
Review comment: Please add num_workers here and also pass it in https://github.com/NVIDIA/NeMo/blob/main/tools/asr_evaluator/utils.py#L118

batch_size: 32
num_workers: 4

augmentor:
silence:
54 changes: 39 additions & 15 deletions tools/asr_evaluator/utils.py
@@ -28,6 +28,9 @@ def run_asr_inference(cfg: DictConfig) -> DictConfig:
if (cfg.model_path and cfg.pretrained_name) or (not cfg.model_path and not cfg.pretrained_name):
raise ValueError("Please specify either cfg.model_path or cfg.pretrained_name!")

if cfg.inference.decoder_type not in [None, 'ctc', 'rnnt']:
raise ValueError("decoder_type could only be null, ctc or rnnt")

if cfg.inference.mode == "offline":
cfg = run_offline_inference(cfg)

@@ -67,6 +70,7 @@


def run_chunked_inference(cfg: DictConfig) -> DictConfig:

if "output_filename" not in cfg or not cfg.output_filename:
if cfg.model_path:
model_name = Path(cfg.model_path).stem
@@ -93,10 +97,28 @@ def run_chunked_inference(cfg: DictConfig) -> DictConfig:
/ "ctc"
/ "speech_to_text_buffered_infer_ctc.py"
)

if (cfg.pretrained_name and 'transducer' in cfg.pretrained_name) or (
use_rnnt_script = False
# hybrid
if (cfg.pretrained_name and 'hybrid' in cfg.pretrained_name) or (cfg.model_path and 'hybrid' in cfg.model_path):
if cfg.inference.decoder_type != 'ctc':
use_rnnt_script = True
# rnnt
elif (cfg.pretrained_name and 'transducer' in cfg.pretrained_name) or (
cfg.model_path and 'transducer' in cfg.model_path
):
if cfg.inference.decoder_type and cfg.inference.decoder_type != 'rnnt':
raise ValueError(
f"rnnt models only support rnnt deocoding! Current decoder_type: {cfg.inference.decoder_type}! Change it to null or rnnt for rnnt models"
)
use_rnnt_script = True
# ctc model
else:
if cfg.inference.decoder_type and cfg.inference.decoder_type != 'ctc':
raise ValueError(
f"ctc models only support ctc deocoding! Current decoder_type: {cfg.inference.decoder_type}! Change it to null or ctc for ctc models"
)

if use_rnnt_script:
script_path = (
Path(__file__).parents[2]
/ "examples"
@@ -106,20 +128,21 @@
/ "speech_to_text_buffered_infer_rnnt.py"
)

base_cmd = f"python {script_path} \
calculate_wer=False \
model_path={cfg.model_path} \
pretrained_name={cfg.pretrained_name} \
dataset_manifest={cfg.test_ds.manifest_filepath} \
output_filename={cfg.output_filename} \
random_seed={cfg.random_seed} \
batch_size={cfg.test_ds.batch_size} \
num_workers={cfg.test_ds.num_workers} \
chunk_len_in_secs={cfg.inference.chunk_len_in_secs} \
total_buffer_in_secs={cfg.inference.total_buffer_in_secs} \
model_stride={cfg.inference.model_stride} "

subprocess.run(
f"python {script_path} "
f"calculate_wer=False "
f"model_path={cfg.model_path} "
f"pretrained_name={cfg.pretrained_name} "
f"dataset_manifest={cfg.test_ds.manifest_filepath} "
f"output_filename={cfg.output_filename} "
f"random_seed={cfg.random_seed} "
f"batch_size={cfg.test_ds.batch_size} "
f"chunk_len_in_secs={cfg.inference.chunk_len_in_secs} "
f"total_buffer_in_secs={cfg.inference.total_buffer_in_secs} "
f"model_stride={cfg.inference.model_stride} ",
shell=True,
check=True,
base_cmd, shell=True, check=True,
)
return cfg
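The branching in run_chunked_inference reads as a small decision table: hybrid checkpoints use the RNNT buffered script unless decoder_type is explicitly ctc, transducer checkpoints always use the RNNT script, and everything else falls back to the CTC script. A hypothetical pure-function restatement, useful mainly for seeing the rule at a glance:

```python
from typing import Optional

def pick_buffered_script(model_name: str, decoder_type: Optional[str]) -> str:
    """Illustrative restatement of the selection logic in run_chunked_inference."""
    if "hybrid" in model_name:
        return "ctc" if decoder_type == "ctc" else "rnnt"
    if "transducer" in model_name:
        if decoder_type not in (None, "rnnt"):
            raise ValueError("transducer models only support rnnt decoding")
        return "rnnt"
    if decoder_type not in (None, "ctc"):
        raise ValueError("ctc models only support ctc decoding")
    return "ctc"

assert pick_buffered_script("stt_en_fastconformer_hybrid_large", None) == "rnnt"
assert pick_buffered_script("stt_en_fastconformer_hybrid_large", "ctc") == "ctc"
assert pick_buffered_script("stt_en_conformer_transducer_large", None) == "rnnt"
assert pick_buffered_script("stt_en_conformer_ctc_large", None) == "ctc"
```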

@@ -153,6 +176,7 @@ def run_offline_inference(cfg: DictConfig) -> DictConfig:
f"dataset_manifest={cfg.test_ds.manifest_filepath} "
f"output_filename={cfg.output_filename} "
f"batch_size={cfg.test_ds.batch_size} "
f"num_workers={cfg.test_ds.num_workers} "
f"random_seed={cfg.random_seed} "
f"eval_config_yaml={f.name} "
f"decoder_type={cfg.inference.decoder_type} ",