Offline and streaming inference support for hybrid model (#6570)
* streaming buffered for hybrid + ctc

Signed-off-by: fayejf <[email protected]>

* change default model_stride in eval.yaml

Signed-off-by: fayejf <[email protected]>

* add fc model_stride

Signed-off-by: fayejf <[email protected]>

* small fix

Signed-off-by: fayejf <[email protected]>

* check whether model and decoding match

Signed-off-by: fayejf <[email protected]>

* small fix

Signed-off-by: fayejf <[email protected]>

* streaming buffered for hybrid + rnnt

Signed-off-by: fayejf <[email protected]>

* style fix

Signed-off-by: fayejf <[email protected]>

* fix yaml

Signed-off-by: fayejf <[email protected]>

* reflect comment wip

Signed-off-by: fayejf <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: fayejf <[email protected]>

* refactor and verified

Signed-off-by: fayejf <[email protected]>

* add get_full_path to buffered

Signed-off-by: fayejf <[email protected]>

* small fix

Signed-off-by: fayejf <[email protected]>

* add RNNTDecodingConfig

Signed-off-by: fayejf <[email protected]>

* model name & instruction of changing decoding

Signed-off-by: fayejf <[email protected]>

---------

Signed-off-by: fayejf <[email protected]>
Signed-off-by: fayejf <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and yaoyu-33 committed May 26, 2023
1 parent ef2019c commit 4b363ce
Showing 7 changed files with 151 additions and 37 deletions.
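The changes below hinge on a hybrid model API NeMo already exposes: change_decoding_strategy accepts a decoder_type argument that routes decoding through either the CTC or the RNNT head. A minimal sketch of that switch, for orientation only; the model name is a placeholder, not something this commit ships:

# Sketch (not part of this commit): switching decoders on a hybrid RNNT-CTC model.
# "<hybrid_model_name>" is a placeholder for any EncDecHybridRNNTCTCModel checkpoint.
from omegaconf import OmegaConf
from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig
from nemo.collections.asr.metrics.wer import CTCDecodingConfig
from nemo.collections.asr.models import EncDecHybridRNNTCTCModel

asr_model = EncDecHybridRNNTCTCModel.from_pretrained("<hybrid_model_name>")

# Decode with the auxiliary CTC head (what the buffered CTC script selects) ...
asr_model.change_decoding_strategy(OmegaConf.structured(CTCDecodingConfig()), decoder_type='ctc')

# ... or with the RNNT head (the default route in the buffered RNNT script).
asr_model.change_decoding_strategy(OmegaConf.structured(RNNTDecodingConfig()), decoder_type='rnnt')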
@@ -48,6 +48,7 @@
from omegaconf import OmegaConf

from nemo.collections.asr.metrics.wer import CTCDecodingConfig
from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel
from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR
from nemo.collections.asr.parts.utils.transcribe_utils import (
@@ -78,10 +79,16 @@ class TranscriptionConfig:
pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than standard one.
random_seed: Optional[int] = None # seed number going to be used in seed_everything()

# Set to True to output greedy timestamp information (only for supported models)
compute_timestamps: bool = False

# Set to True to output language ID information
compute_langs: bool = False

# Chunked configs
chunk_len_in_secs: float = 1.6 # Chunk length in seconds
total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models",
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet and FastConformer models and 4 for Conformer models.

# Decoding strategy for CTC models
decoding: CTCDecodingConfig = CTCDecodingConfig()
@@ -108,6 +115,9 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
torch.set_grad_enabled(False)

for key in cfg:
cfg[key] = None if cfg[key] == 'None' else cfg[key]

if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)

@@ -174,6 +184,23 @@ def autocast():
)
return cfg

# Setup decoding strategy
if hasattr(asr_model, 'change_decoding_strategy'):
if not isinstance(asr_model, EncDecCTCModel) and not isinstance(asr_model, EncDecHybridRNNTCTCModel):
raise ValueError("The script supports ctc model and hybrid model with ctc decodng!")

else:
if cfg.compute_langs:
raise ValueError("CTC models do not support `compute_langs` at the moment.")

if hasattr(
asr_model, 'cur_decoder'
): # hybrid model with ctc decoding, or other models that expose a decoder-switching feature
asr_model.change_decoding_strategy(cfg.decoding, decoder_type='ctc')

else: # ctc model
asr_model.change_decoding_strategy(cfg.decoding)

asr_model.eval()
asr_model = asr_model.to(asr_model.device)

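The hunk above rejects models that are neither CTC nor hybrid, and routes hybrid models to their CTC head via the cur_decoder attribute. A condensed, illustrative restatement of that logic as a helper; a sketch, not the script's exact code:

# Sketch of the decoder selection added to the buffered CTC inference script above.
from omegaconf import OmegaConf
from nemo.collections.asr.metrics.wer import CTCDecodingConfig
from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel


def setup_ctc_decoding(asr_model, decoding_cfg=None):
    decoding_cfg = decoding_cfg if decoding_cfg is not None else OmegaConf.structured(CTCDecodingConfig())
    if not isinstance(asr_model, (EncDecCTCModel, EncDecHybridRNNTCTCModel)):
        raise ValueError("Only CTC models and hybrid models with CTC decoding are supported.")
    if hasattr(asr_model, 'cur_decoder'):
        # hybrid RNNT-CTC model (or any model exposing a decoder switch): use the CTC head
        asr_model.change_decoding_strategy(decoding_cfg, decoder_type='ctc')
    else:
        # plain CTC model
        asr_model.change_decoding_strategy(decoding_cfg)
    return asr_model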
@@ -67,7 +67,8 @@
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf, open_dict

from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig
from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel
from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
from nemo.collections.asr.parts.utils.streaming_utils import (
BatchedFrameASRRNNT,
@@ -101,10 +102,16 @@ class TranscriptionConfig:
pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than standard one.
random_seed: Optional[int] = None # seed number going to be used in seed_everything()

# Set to True to output greedy timestamp information (only for supported models)
compute_timestamps: bool = False

# Set to True to output language ID information
compute_langs: bool = False

# Chunked configs
chunk_len_in_secs: float = 1.6 # Chunk length in seconds
total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet and FastConformer models and 4 for Conformer models.

# Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
# device anyway, and do inference on CPU only if CUDA device is not found.
@@ -115,6 +122,9 @@ class TranscriptionConfig:
# Recompute model transcription, even if the output folder exists with scores.
overwrite_transcripts: bool = True

# Decoding strategy for RNNT models
decoding: RNNTDecodingConfig = RNNTDecodingConfig()

# Decoding configs
max_steps_per_timestep: int = 5 #'Maximum number of tokens decoded per acoustic timestep'
stateful_decoding: bool = False # Whether to perform stateful decoding
@@ -135,6 +145,9 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
torch.set_grad_enabled(False)

for key in cfg:
cfg[key] = None if cfg[key] == 'None' else cfg[key]

if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)

@@ -195,20 +208,27 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
asr_model = asr_model.to(asr_model.device)

# Change Decoding Config
decoding_cfg = asr_model.cfg.decoding
with open_dict(decoding_cfg):
with open_dict(cfg.decoding):
if cfg.stateful_decoding:
decoding_cfg.strategy = "greedy"
cfg.decoding.strategy = "greedy"
else:
decoding_cfg.strategy = "greedy_batch"
decoding_cfg.preserve_alignments = True # required to compute the middle token for transducers.
decoding_cfg.fused_batch_size = -1 # temporarily stop fused batch during inference.
decoding_cfg.beam.return_best_hypothesis = True

asr_model.change_decoding_strategy(decoding_cfg)
cfg.decoding.strategy = "greedy_batch"
cfg.decoding.preserve_alignments = True # required to compute the middle token for transducers.
cfg.decoding.fused_batch_size = -1 # temporarily stop fused batch during inference.
cfg.decoding.beam.return_best_hypothesis = True # return and write the best hypothesis only

# Setup decoding strategy
if hasattr(asr_model, 'change_decoding_strategy'):
if not isinstance(asr_model, EncDecRNNTModel) and not isinstance(asr_model, EncDecHybridRNNTCTCModel):
raise ValueError("The script supports rnnt model and hybrid model with rnnt decodng!")
else:
# rnnt model
if isinstance(asr_model, EncDecRNNTModel):
asr_model.change_decoding_strategy(cfg.decoding)

with open_dict(cfg):
cfg.decoding = decoding_cfg
# hybrid ctc rnnt model with decoder_type = rnnt
if hasattr(asr_model, 'cur_decoder'):
asr_model.change_decoding_strategy(cfg.decoding, decoder_type='rnnt')

feature_stride = model_cfg.preprocessor['window_stride']
model_stride_in_secs = feature_stride * cfg.model_stride
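With this change the overrides live on cfg.decoding (an RNNTDecodingConfig) rather than on the restored model's own decoding config. The same settings built stand-alone, assuming unlisted fields keep their defaults:

# Sketch of the RNNT decoding overrides applied above (illustrative only).
from omegaconf import OmegaConf, open_dict
from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig

decoding = OmegaConf.structured(RNNTDecodingConfig())
with open_dict(decoding):
    decoding.strategy = "greedy_batch"           # "greedy" when stateful_decoding is enabled
    decoding.preserve_alignments = True          # required to compute the middle token for transducers
    decoding.fused_batch_size = -1               # temporarily stop fused batch during inference
    decoding.beam.return_best_hypothesis = True  # return and write the best hypothesis only

# decoding is then passed to asr_model.change_decoding_strategy(...),
# with decoder_type='rnnt' when the model is a hybrid RNNT-CTC model.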
16 changes: 15 additions & 1 deletion examples/asr/transcribe_speech.py
@@ -227,6 +227,17 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
compute_timestamps = cfg.compute_timestamps
compute_langs = cfg.compute_langs

# Check whether model and decoder type match
if isinstance(asr_model, EncDecCTCModel):
if cfg.decoder_type and cfg.decoder_type != 'ctc':
raise ValueError('CTC models only support ctc decoding!')
elif isinstance(asr_model, EncDecHybridRNNTCTCModel):
if cfg.decoder_type and cfg.decoder_type not in ['ctc', 'rnnt']:
raise ValueError('Hybrid models only support ctc or rnnt decoding!')
else: # rnnt model; other model types may need to be handled here.
if cfg.decoder_type and cfg.decoder_type != 'rnnt':
raise ValueError('RNNT models only support rnnt decoding!')

# Setup decoding strategy
if hasattr(asr_model, 'change_decoding_strategy'):
if cfg.decoder_type is not None:
@@ -240,7 +251,10 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
decoding_cfg.preserve_alignments = cfg.compute_timestamps
if 'compute_langs' in decoding_cfg:
decoding_cfg.compute_langs = cfg.compute_langs
asr_model.change_decoding_strategy(decoding_cfg, decoder_type=cfg.decoder_type)
if hasattr(asr_model, 'cur_decoder'):
asr_model.change_decoding_strategy(decoding_cfg, decoder_type=cfg.decoder_type)
else:
asr_model.change_decoding_strategy(decoding_cfg)

# Check if ctc or rnnt model
elif hasattr(asr_model, 'joint'): # RNNT model
12 changes: 9 additions & 3 deletions nemo/collections/asr/parts/utils/streaming_utils.py
@@ -769,9 +769,15 @@ def _get_batch_preds(self, keep_logits=False):

feat_signal, feat_signal_len = batch
feat_signal, feat_signal_len = feat_signal.to(device), feat_signal_len.to(device)
log_probs, encoded_len, predictions = self.asr_model(
processed_signal=feat_signal, processed_signal_length=feat_signal_len
)
forward_outs = self.asr_model(processed_signal=feat_signal, processed_signal_length=feat_signal_len)

if len(forward_outs) == 2: # hybrid ctc rnnt model
encoded, encoded_len = forward_outs
log_probs = self.asr_model.ctc_decoder(encoder_output=encoded)
predictions = log_probs.argmax(dim=-1, keepdim=False)
else:
log_probs, encoded_len, predictions = forward_outs

preds = torch.unbind(predictions)
for pred in preds:
self.all_preds.append(pred.cpu().numpy())
6 changes: 4 additions & 2 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
@@ -58,7 +58,8 @@ def get_buffered_pred_feat_rnnt(
print("Parsing manifest files...")
for l in mfst_f:
row = json.loads(l.strip())
filepaths.append(row['audio_filepath'])
audio_file = get_full_path(audio_file=row['audio_filepath'], manifest_file=manifest)
filepaths.append(audio_file)
if 'text' in row:
refs.append(row['text'])

@@ -149,8 +150,9 @@ def get_buffered_pred_feat(
row = json.loads(l.strip())
if 'text' in row:
refs.append(row['text'])
audio_file = get_full_path(audio_file=row['audio_filepath'], manifest_file=manifest)
# do not support partial audio
asr.read_audio_file(row['audio_filepath'], delay, model_stride_in_secs)
asr.read_audio_file(audio_file, delay, model_stride_in_secs)
hyp = asr.transcribe(tokens_per_chunk, delay)
hyps.append(hyp)

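Both hunks above route manifest entries through get_full_path so that relative audio_filepath values are resolved before the audio is read. A simplified stand-alone sketch of that behaviour (not NeMo's actual helper):

# Simplified sketch: resolve a relative audio path against the manifest's directory.
import os


def resolve_audio_path(audio_file: str, manifest_file: str) -> str:
    if not os.path.isabs(audio_file) and not os.path.isfile(audio_file):
        candidate = os.path.join(os.path.dirname(os.path.abspath(manifest_file)), audio_file)
        if os.path.isfile(candidate):
            return candidate
    return audio_file

# e.g. a manifest at /data/manifests/test.json with "audio_filepath": "wavs/utt1.wav"
# resolves to /data/manifests/wavs/utt1.wav when that file exists.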
4 changes: 3 additions & 1 deletion tools/asr_evaluator/conf/eval.yaml
@@ -12,12 +12,14 @@ engine:
mode: offline # choose from offline, chunked or offline_by_chunked
chunk_len_in_secs: 1.6 #null # Must be specified when using buffered inference (default for offline_by_chunked is 20)
total_buffer_in_secs: 4 #null # Must be specified when using buffered inference (default for offline_by_chunked is 22)
model_stride: 4 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models
model_stride: 8 # Model downsampling factor, 8 for Citrinet and FastConformer models, and 4 for Conformer models
decoder_type: null # Used for hybrid CTC-RNNT models only. Set to *ctc* or *rnnt* to select the decoder.

test_ds:
manifest_filepath: null
sample_rate: 16000
batch_size: 32
num_workers: 4

augmentor:
silence:
75 changes: 59 additions & 16 deletions tools/asr_evaluator/utils.py
@@ -28,6 +28,9 @@ def run_asr_inference(cfg: DictConfig) -> DictConfig:
if (cfg.model_path and cfg.pretrained_name) or (not cfg.model_path and not cfg.pretrained_name):
raise ValueError("Please specify either cfg.model_path or cfg.pretrained_name!")

if cfg.inference.decoder_type not in [None, 'ctc', 'rnnt']:
raise ValueError("decoder_type could only be null, ctc or rnnt")

if cfg.inference.mode == "offline":
cfg = run_offline_inference(cfg)

@@ -67,6 +70,7 @@


def run_chunked_inference(cfg: DictConfig) -> DictConfig:

if "output_filename" not in cfg or not cfg.output_filename:
if cfg.model_path:
model_name = Path(cfg.model_path).stem
@@ -93,10 +97,43 @@ def run_chunked_inference(cfg: DictConfig) -> DictConfig:
/ "ctc"
/ "speech_to_text_buffered_infer_ctc.py"
)
use_rnnt_script = False
# hybrid model
if (cfg.pretrained_name and 'hybrid' in cfg.pretrained_name.lower()) or (
cfg.model_path and 'hybrid' in cfg.model_path.lower()
):
if cfg.inference.decoder_type != 'ctc':
use_rnnt_script = True
# rnnt model
elif (
(cfg.pretrained_name and 'rnnt' in cfg.pretrained_name.lower())
or (cfg.pretrained_name and 'transducer' in cfg.pretrained_name.lower())
or (cfg.model_path and 'rnnt' in cfg.model_path.lower())
or (cfg.model_path and 'transducer' in cfg.model_path.lower())
):
if cfg.inference.decoder_type and cfg.inference.decoder_type != 'rnnt':
raise ValueError(
f"rnnt models only support rnnt deocoding! Current decoder_type: {cfg.inference.decoder_type}! Change it to null or rnnt for rnnt models"
)
use_rnnt_script = True

if (cfg.pretrained_name and 'transducer' in cfg.pretrained_name) or (
cfg.model_path and 'transducer' in cfg.model_path
# ctc model
elif (cfg.pretrained_name and 'ctc' in cfg.pretrained_name.lower()) or (
cfg.model_path and 'ctc' in cfg.model_path.lower()
):
if cfg.inference.decoder_type and cfg.inference.decoder_type != 'ctc':
raise ValueError(
f"ctc models only support ctc deocoding! Current decoder_type: {cfg.inference.decoder_type}! Change it to null or ctc for ctc models"
)
else:
raise ValueError(
"Please make sure your pretrained_name or model_path contains \n\
'hybrid' for EncDecHybridRNNTCTCModel model, \n\
'transducer/rnnt' for EncDecRNNTModel model or \n\
'ctc' for EncDecCTCModel."
)

if use_rnnt_script:
script_path = (
Path(__file__).parents[2]
/ "examples"
@@ -106,20 +143,25 @@ def run_chunked_inference(cfg: DictConfig) -> DictConfig:
/ "speech_to_text_buffered_infer_rnnt.py"
)

# If you need to change other configs such as the decoding strategy, you can either:
# 1) edit the TranscriptionConfig at the top of the executed script, e.g. speech_to_text_buffered_infer_rnnt.py, or
# 2) append an override such as "decoding.strategy=greedy_batch " to the command below

base_cmd = f"python {script_path} \
calculate_wer=False \
model_path={cfg.model_path} \
pretrained_name={cfg.pretrained_name} \
dataset_manifest={cfg.test_ds.manifest_filepath} \
output_filename={cfg.output_filename} \
random_seed={cfg.random_seed} \
batch_size={cfg.test_ds.batch_size} \
num_workers={cfg.test_ds.num_workers} \
chunk_len_in_secs={cfg.inference.chunk_len_in_secs} \
total_buffer_in_secs={cfg.inference.total_buffer_in_secs} \
model_stride={cfg.inference.model_stride} "

subprocess.run(
f"python {script_path} "
f"calculate_wer=False "
f"model_path={cfg.model_path} "
f"pretrained_name={cfg.pretrained_name} "
f"dataset_manifest={cfg.test_ds.manifest_filepath} "
f"output_filename={cfg.output_filename} "
f"random_seed={cfg.random_seed} "
f"batch_size={cfg.test_ds.batch_size} "
f"chunk_len_in_secs={cfg.inference.chunk_len_in_secs} "
f"total_buffer_in_secs={cfg.inference.total_buffer_in_secs} "
f"model_stride={cfg.inference.model_stride} ",
shell=True,
check=True,
base_cmd, shell=True, check=True,
)
return cfg

@@ -142,7 +184,7 @@ def run_offline_inference(cfg: DictConfig) -> DictConfig:
f.seek(0) # reset file pointer
script_path = Path(__file__).parents[2] / "examples" / "asr" / "transcribe_speech.py"

# If need to move other config such as decoding strategy, could either:
# If you need to change other configs such as the decoding strategy, you can either:
# 1) edit the TranscriptionConfig at the top of the executed script, e.g. transcribe_speech.py in examples/asr, or
# 2) append an override such as "rnnt_decoding.strategy=greedy_batch " to the command below
subprocess.run(
Expand All @@ -153,6 +195,7 @@ def run_offline_inference(cfg: DictConfig) -> DictConfig:
f"dataset_manifest={cfg.test_ds.manifest_filepath} "
f"output_filename={cfg.output_filename} "
f"batch_size={cfg.test_ds.batch_size} "
f"num_workers={cfg.test_ds.num_workers} "
f"random_seed={cfg.random_seed} "
f"eval_config_yaml={f.name} "
f"decoder_type={cfg.inference.decoder_type} ",
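The branching added to run_chunked_inference boils down to a model-name rule: 'hybrid' picks the buffered RNNT script unless decoder_type=ctc is requested, while 'rnnt'/'transducer' and 'ctc' names must agree with decoder_type. A condensed, illustrative distillation of that rule; not the evaluator's exact code:

# Illustrative distillation of the script-selection rule above.
from typing import Optional


def choose_buffered_script(model_name: str, decoder_type: Optional[str]) -> str:
    name = model_name.lower()
    if 'hybrid' in name:
        # hybrid models default to RNNT decoding unless decoder_type='ctc' is requested
        return 'rnnt' if decoder_type != 'ctc' else 'ctc'
    if 'rnnt' in name or 'transducer' in name:
        if decoder_type not in (None, 'rnnt'):
            raise ValueError("rnnt models only support rnnt decoding.")
        return 'rnnt'
    if 'ctc' in name:
        if decoder_type not in (None, 'ctc'):
            raise ValueError("ctc models only support ctc decoding.")
        return 'ctc'
    raise ValueError("Model name must contain 'hybrid', 'rnnt'/'transducer', or 'ctc'.")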
