Skip to content

Commit

Permalink
Patch transcribe and support offline transcribe for hybrid model (#6550
Browse files Browse the repository at this point in the history
…) (#6559)

Signed-off-by: fayejf <[email protected]>
Co-authored-by: fayejf <[email protected]>
  • Loading branch information
2 people authored and yaoyu-33 committed May 26, 2023
1 parent 415d02b commit bccb46e
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 7 deletions.
17 changes: 14 additions & 3 deletions examples/asr/transcribe_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from omegaconf import OmegaConf, open_dict

from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig
from nemo.collections.asr.metrics.wer import CTCDecodingConfig
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel
from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig
from nemo.collections.asr.parts.utils.transcribe_utils import (
compute_output_filename,
Expand Down Expand Up @@ -154,6 +154,9 @@ class TranscriptionConfig:
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

for key in cfg:
cfg[key] = None if cfg[key] == 'None' else cfg[key]

if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)

Expand Down Expand Up @@ -223,7 +226,6 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
decoding_cfg.preserve_alignments = cfg.compute_timestamps
if 'compute_langs' in decoding_cfg:
decoding_cfg.compute_langs = cfg.compute_langs

asr_model.change_decoding_strategy(decoding_cfg, decoder_type=cfg.decoder_type)

# Check if ctc or rnnt model
Expand All @@ -243,6 +245,15 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:

asr_model.change_decoding_strategy(cfg.ctc_decoding)

# Setup decoding config based on model type and decoder_type
with open_dict(cfg):
if isinstance(asr_model, EncDecCTCModel) or (
isinstance(asr_model, EncDecHybridRNNTCTCModel) and cfg.decoder_type == "ctc"
):
cfg.decoding = cfg.ctc_decoding
else:
cfg.decoding = cfg.rnnt_decoding

# prepare audio filepaths and decide wether it's partical audio
filepaths, partial_audio = prepare_audio_data(cfg)

Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,14 +289,14 @@ def write_transcription(

if isinstance(transcriptions[0], rnnt_utils.Hypothesis): # List[rnnt_utils.Hypothesis]
best_hyps = transcriptions
assert cfg.ctc_decoding.beam.return_best_hypothesis, "Works only with return_best_hypothesis=true"
assert cfg.decoding.beam.return_best_hypothesis, "Works only with return_best_hypothesis=true"
elif isinstance(transcriptions[0], list) and isinstance(
transcriptions[0][0], rnnt_utils.Hypothesis
): # List[List[rnnt_utils.Hypothesis]] NBestHypothesis
best_hyps, beams = [], []
for hyps in transcriptions:
best_hyps.append(hyps[0])
if not cfg.ctc_decoding.beam.return_best_hypothesis:
if not cfg.decoding.beam.return_best_hypothesis:
beam = []
for hyp in hyps:
beam.append((hyp.text, hyp.score))
Expand Down
2 changes: 1 addition & 1 deletion tools/asr_evaluator/conf/eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ engine:
chunk_len_in_secs: 1.6 #null # Need to specify if use buffered inference (default for offline_by_chunked is 20)
total_buffer_in_secs: 4 #null # Need to specify if use buffered inference (default for offline_by_chunked is 22)
model_stride: 4 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models

decoder_type: null # Used for hybrid CTC RNNT model only. Specify decoder_type *ctc* or *rnnt* for hybrid CTC RNNT model.
test_ds:
manifest_filepath: null
sample_rate: 16000
Expand Down
3 changes: 2 additions & 1 deletion tools/asr_evaluator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def run_offline_inference(cfg: DictConfig) -> DictConfig:
f"output_filename={cfg.output_filename} "
f"batch_size={cfg.test_ds.batch_size} "
f"random_seed={cfg.random_seed} "
f"eval_config_yaml={f.name} ",
f"eval_config_yaml={f.name} "
f"decoder_type={cfg.inference.decoder_type} ",
shell=True,
check=True,
)
Expand Down

0 comments on commit bccb46e

Please sign in to comment.