Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Offline and streaming inference support for hybrid model #6570

Merged
merged 20 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from omegaconf import OmegaConf, open_dict
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Show resolved Hide resolved

from nemo.collections.asr.metrics.wer import CTCDecodingConfig
from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR
from nemo.collections.asr.parts.utils.transcribe_utils import (
compute_output_filename,
Expand Down Expand Up @@ -74,10 +76,19 @@ class TranscriptionConfig:
pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than standard one.
random_seed: Optional[int] = None # seed number going to be used in seed_everything()

# Set to True to output greedy timestamp information (only supported models)
compute_timestamps: bool = False

# Set to True to output language ID information
compute_langs: bool = False

# Chunked configs
chunk_len_in_secs: float = 1.6 # Chunk length in seconds
total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models",
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet and FasConformer models and 4 for Conformer models",

# Decoding strategy for CTC models
ctc_decoding: CTCDecodingConfig = CTCDecodingConfig()
VahidooX marked this conversation as resolved.
Show resolved Hide resolved

# Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
# device anyway, and do inference on CPU only if CUDA device is not found.
Expand All @@ -89,12 +100,18 @@ class TranscriptionConfig:
# Recompute model transcription, even if the output folder exists with scores.
overwrite_transcripts: bool = True

# decoder type for hybrid model could be None for ctc model and ctc for hybrid model
decoder_type: Optional[str] = None
VahidooX marked this conversation as resolved.
Show resolved Hide resolved


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
torch.set_grad_enabled(False)

for key in cfg:
cfg[key] = None if cfg[key] == 'None' else cfg[key]

if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)

Expand All @@ -106,6 +123,11 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
if cfg.audio_dir is None and cfg.dataset_manifest is None:
raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")

if cfg.decoder_type not in [None, 'ctc']:
raise ValueError(
"decoder_type needs to be either None (ctc model) or ctc (hybrid model with ctc decoder)for speech_to_text_buffered_infer_ctc!"
)

filepaths = None
manifest = cfg.dataset_manifest
if cfg.audio_dir is not None:
Expand Down Expand Up @@ -161,6 +183,32 @@ def autocast():
)
return cfg

# Setup decoding strategy
if hasattr(asr_model, 'change_decoding_strategy'):
if not isinstance(asr_model, EncDecCTCModel) and not isinstance(asr_model, EncDecHybridRNNTCTCModel):
raise ValueError(
"The script supports ctc model and hybrid model with ctc decodng! use rnnt/speech_to_text_buffered_infer_rnnt.py for other conditions."
)

else:
if cfg.compute_langs:
VahidooX marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("CTC models do not support `compute_langs` at the moment.")

# ctc model
if isinstance(asr_model, EncDecCTCModel):
asr_model.change_decoding_strategy(cfg.ctc_decoding)

# hybrid ctc rnnt model with decoder_type=ctc
if isinstance(asr_model, EncDecHybridRNNTCTCModel):
if cfg.decoder_type != "ctc":
VahidooX marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"If the model is a EncDecHybridRNNTCTCModel model, decoder_type should be specified and set to 'ctc' for this script."
)
asr_model.change_decoding_strategy(cfg.ctc_decoding, decoder_type=cfg.decoder_type)

with open_dict(cfg):
VahidooX marked this conversation as resolved.
Show resolved Hide resolved
cfg.decoding = cfg.ctc_decoding

asr_model.eval()
asr_model = asr_model.to(asr_model.device)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import torch
from omegaconf import OmegaConf, open_dict

from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel
from nemo.collections.asr.parts.utils.streaming_utils import (
BatchedFrameASRRNNT,
LongestCommonSubsequenceBatchedFrameASRRNNT,
Expand Down Expand Up @@ -98,10 +99,19 @@ class TranscriptionConfig:
pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than standard one.
random_seed: Optional[int] = None # seed number going to be used in seed_everything()

# Set to True to output greedy timestamp information (only supported models)
compute_timestamps: bool = False

# Set to True to output language ID information
compute_langs: bool = False

# Chunked configs
chunk_len_in_secs: float = 1.6 # Chunk length in seconds
total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models",
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet and FastConformer models and 4 for Conformer models",

# # Decoding strategy for RNNT models
# rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig()

# Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
# device anyway, and do inference on CPU only if CUDA device is not found.
Expand All @@ -112,6 +122,9 @@ class TranscriptionConfig:
# Recompute model transcription, even if the output folder exists with scores.
overwrite_transcripts: bool = True

# Decoder type for hybrid model could be None for ctc model and ctc for hybrid model
decoder_type: Optional[str] = None
VahidooX marked this conversation as resolved.
Show resolved Hide resolved

# Decoding configs
max_steps_per_timestep: int = 5 #'Maximum number of tokens decoded per acoustic timestep'
stateful_decoding: bool = False # Whether to perform stateful decoding
Expand All @@ -126,6 +139,9 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
torch.set_grad_enabled(False)

for key in cfg:
cfg[key] = None if cfg[key] == 'None' else cfg[key]

if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)

Expand All @@ -143,6 +159,11 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
manifest = None # ignore dataset_manifest if audio_dir and dataset_manifest both presents

if cfg.decoder_type not in [None, 'rnnt']:
raise ValueError(
"decoder_type needs to be either None (rnnt model) or ctc (hybrid model with rnnt decoder)for speech_to_text_buffered_infer_rnnt!"
)

# setup GPU
if cfg.cuda is None:
if torch.cuda.is_available():
Expand Down Expand Up @@ -194,8 +215,29 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
decoding_cfg.strategy = "greedy_batch"
decoding_cfg.preserve_alignments = True # required to compute the middle token for transducers.
decoding_cfg.fused_batch_size = -1 # temporarily stop fused batch during inference.

asr_model.change_decoding_strategy(decoding_cfg)
decoding_cfg.beam.return_best_hypothesis = True # return and write the best hypothsis only
VahidooX marked this conversation as resolved.
Show resolved Hide resolved

# Setup decoding strategy
if hasattr(asr_model, 'change_decoding_strategy'):
if not isinstance(asr_model, EncDecRNNTModel) and not isinstance(asr_model, EncDecHybridRNNTCTCModel):
raise ValueError(
"The script supports rnnt model and hybrid model with rnnt decodng! use ctc/speech_to_text_buffered_infer_ctc.py for other conditions."
)
else:
# rnnt model
if isinstance(asr_model, EncDecRNNTModel):
asr_model.change_decoding_strategy(decoding_cfg)

# hybrid ctc rnnt model with decoder_type = rnnt
if isinstance(asr_model, EncDecHybridRNNTCTCModel):
VahidooX marked this conversation as resolved.
Show resolved Hide resolved
if not cfg.decoder_type or cfg.decoder_type != "rnnt":
raise ValueError(
"If the model is a EncDecHybridRNNTCTCModel model, decoder_type should either be null or set to 'rnnt' for this script."
)
asr_model.change_decoding_strategy(decoding_cfg, decoder_type=cfg.decoder_type)

with open_dict(cfg):
cfg.decoding = decoding_cfg

feature_stride = model_cfg.preprocessor['window_stride']
model_stride_in_secs = feature_stride * cfg.model_stride
Expand Down
18 changes: 16 additions & 2 deletions examples/asr/transcribe_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig
from nemo.collections.asr.metrics.wer import CTCDecodingConfig
from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel
from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel, EncDecRNNTModel
Fixed Show fixed Hide fixed
from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig
from nemo.collections.asr.parts.utils.transcribe_utils import (
compute_output_filename,
Expand Down Expand Up @@ -213,6 +213,17 @@
compute_timestamps = cfg.compute_timestamps
compute_langs = cfg.compute_langs

# Check whether model and decoder type match
if isinstance(asr_model, EncDecCTCModel):
VahidooX marked this conversation as resolved.
Show resolved Hide resolved
if cfg.decoder_type and cfg.decoder_type != 'ctc':
raise ValueError('CTC model only support ctc decoding!')
elif isinstance(asr_model, EncDecHybridRNNTCTCModel):
if cfg.decoder_type and cfg.decoder_type not in ['ctc', 'rnnt']:
raise ValueError('Hybrid model only support ctc or rnnt decoding!')
else: # rnnt model, there could be other models needs to be addressed.
if cfg.decoder_type and cfg.decoder_type != 'rnnt':
raise ValueError('RNNT model only support rnnt decoding!')

# Setup decoding strategy
if hasattr(asr_model, 'change_decoding_strategy'):
if cfg.decoder_type is not None:
Expand All @@ -226,7 +237,10 @@
decoding_cfg.preserve_alignments = cfg.compute_timestamps
if 'compute_langs' in decoding_cfg:
decoding_cfg.compute_langs = cfg.compute_langs
asr_model.change_decoding_strategy(decoding_cfg, decoder_type=cfg.decoder_type)
if isinstance(asr_model, EncDecHybridRNNTCTCModel):
VahidooX marked this conversation as resolved.
Show resolved Hide resolved
asr_model.change_decoding_strategy(decoding_cfg, decoder_type=cfg.decoder_type)
else:
asr_model.change_decoding_strategy(decoding_cfg)

# Check if ctc or rnnt model
elif hasattr(asr_model, 'joint'): # RNNT model
Expand Down
12 changes: 9 additions & 3 deletions nemo/collections/asr/parts/utils/streaming_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,9 +769,15 @@ def _get_batch_preds(self, keep_logits=False):

feat_signal, feat_signal_len = batch
feat_signal, feat_signal_len = feat_signal.to(device), feat_signal_len.to(device)
log_probs, encoded_len, predictions = self.asr_model(
processed_signal=feat_signal, processed_signal_length=feat_signal_len
)
forward_outs = self.asr_model(processed_signal=feat_signal, processed_signal_length=feat_signal_len)

if len(forward_outs) == 2: # hybrid ctc rnnt model
encoded, encoded_len = forward_outs
log_probs = self.asr_model.ctc_decoder(encoder_output=encoded)
predictions = log_probs.argmax(dim=-1, keepdim=False)
else:
log_probs, encoded_len, predictions = forward_outs

preds = torch.unbind(predictions)
for pred in preds:
self.all_preds.append(pred.cpu().numpy())
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def write_transcription(
if compute_langs:
item['pred_lang'] = transcription.langs
item['pred_lang_chars'] = transcription.langs_chars
if not cfg.ctc_decoding.beam.return_best_hypothesis:
if not cfg.decoding.beam.return_best_hypothesis:
item['beams'] = beams[idx]
f.write(json.dumps(item) + "\n")
else:
Expand All @@ -344,7 +344,7 @@ def write_transcription(
item['pred_lang'] = best_hyps[idx].langs
item['pred_lang_chars'] = best_hyps[idx].langs_chars

if not cfg.ctc_decoding.beam.return_best_hypothesis:
if not cfg.decoding.beam.return_best_hypothesis:
item['beams'] = beams[idx]
f.write(json.dumps(item) + "\n")

Expand Down
5 changes: 3 additions & 2 deletions tools/asr_evaluator/conf/eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ engine:
mode: offline # choose from offline, chunked or offline_by_chunked
chunk_len_in_secs: 1.6 #null # Need to specify if use buffered inference (default for offline_by_chunked is 20)
total_buffer_in_secs: 4 #null # Need to specify if use buffered inference (default for offline_by_chunked is 22)
model_stride: 4 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models
model_stride: 8 # Model downsampling factor, 8 for Citrinet and FastConformer models, and 4 for Conformer models
decoder_type: null # Used for hybrid CTC RNNT model only. Specify decoder_type *ctc* or *rnnt* for hybrid CTC RNNT model.
test_ds:

test_ds:
manifest_filepath: null
sample_rate: 16000
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add num_workers here and also pass it in https://github.com/NVIDIA/NeMo/blob/main/tools/asr_evaluator/utils.py#L118

batch_size: 32
Expand Down
10 changes: 7 additions & 3 deletions tools/asr_evaluator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,11 @@ def run_chunked_inference(cfg: DictConfig) -> DictConfig:
/ "speech_to_text_buffered_infer_ctc.py"
)

if (cfg.pretrained_name and 'transducer' in cfg.pretrained_name) or (
cfg.model_path and 'transducer' in cfg.model_path
if (
VahidooX marked this conversation as resolved.
Show resolved Hide resolved
(cfg.pretrained_name and 'transducer' in cfg.pretrained_name)
or (cfg.model_path and 'transducer' in cfg.model_path)
or (cfg.pretrained_name and 'hybrid' in cfg.pretrained_name and cfg.inference.decoder_type != 'ctc')
or (cfg.model_path and 'hybrid' in cfg.model_path and cfg.inference.decoder_type != 'ctc')
):
script_path = (
Path(__file__).parents[2]
Expand All @@ -118,7 +121,8 @@ def run_chunked_inference(cfg: DictConfig) -> DictConfig:
f"batch_size={cfg.test_ds.batch_size} "
f"chunk_len_in_secs={cfg.inference.chunk_len_in_secs} "
f"total_buffer_in_secs={cfg.inference.total_buffer_in_secs} "
f"model_stride={cfg.inference.model_stride} ",
f"model_stride={cfg.inference.model_stride} "
f"decoder_type={cfg.inference.decoder_type} ",
shell=True,
check=True,
)
Expand Down
Loading