Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

eval_beamsearch_ngram.py with hybrid ctc #6656

Merged
merged 16 commits into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/asr/speech_to_text_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class EvaluationConfig(transcribe_speech.TranscriptionConfig):
only_score_manifest: bool = False

text_processing: Optional[TextProcessingConfig] = TextProcessingConfig(
punctuation_marks=".,?", separate_punctuation=True, do_lowercase=False, rm_punctuation=False,
punctuation_marks=".,?", separate_punctuation=False, do_lowercase=False, rm_punctuation=False,
VahidooX marked this conversation as resolved.
Show resolved Hide resolved
)


Expand Down Expand Up @@ -134,6 +134,7 @@ def main(cfg: EvaluationConfig):
pc = PunctuationCapitalization(cfg.text_processing.punctuation_marks)
if cfg.text_processing.separate_punctuation:
ground_truth_text = pc.separate_punctuation(ground_truth_text)
predicted_text = pc.separate_punctuation(predicted_text)
if cfg.text_processing.do_lowercase:
ground_truth_text = pc.do_lowercase(ground_truth_text)
predicted_text = pc.do_lowercase(predicted_text)
Expand Down
1 change: 1 addition & 0 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

self.ctc_decoding = CTCBPEDecoding(decoding_cfg=decoding_cfg, tokenizer=self.tokenizer)
self.decoding = self.ctc_decoding
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.decoding is for RNNT.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed


self.ctc_wer = WERBPE(
decoding=self.ctc_decoding,
Expand Down
12 changes: 11 additions & 1 deletion nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def transcribe(
channel_selector: Optional[ChannelSelectorType] = None,
augmentor: DictConfig = None,
verbose: bool = True,
logprobs: bool = False,
) -> (List[str], Optional[List['Hypothesis']]):
"""
Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
Expand All @@ -119,6 +120,7 @@ def transcribe(
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied.
verbose: (bool) whether to display tqdm progress bar
logprobs: (bool) whether to return ctc logits insted of hypotheses

Returns:
Returns a tuple of 2 items -
Expand Down Expand Up @@ -189,6 +191,7 @@ def transcribe(
config['augmentor'] = augmentor

temporary_datalayer = self._setup_transcribe_dataloader(config)
logits_list = []
for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=not verbose):
encoded, encoded_len = self.forward(
input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
Expand All @@ -206,6 +209,9 @@ def transcribe(
best_hyp[idx].y_sequence = logits[idx][: encoded_len[idx]]
if best_hyp[idx].alignments is None:
best_hyp[idx].alignments = best_hyp[idx].y_sequence
if logprobs:
for logit, elen in zip(logits, encoded_len):
logits_list.append(logit[:elen])
del logits

hypotheses += best_hyp
Expand All @@ -229,7 +235,10 @@ def transcribe(
self.joint.unfreeze()
if hasattr(self, 'ctc_decoder'):
self.ctc_decoder.unfreeze()
return hypotheses, all_hypotheses
if logprobs:
return logits_list, None
else:
return hypotheses, all_hypotheses

def change_vocabulary(
self,
Expand Down Expand Up @@ -339,6 +348,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

self.ctc_decoding = CTCDecoding(decoding_cfg=decoding_cfg, vocabulary=self.ctc_decoder.vocabulary)
self.decoding = self.ctc_decoding
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need to do it here. It would overwrite the RNNT decoding and we cannot switch back. There are other places in the code which are not based on this assumption.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed


self.ctc_wer = WER(
decoding=self.ctc_decoding,
Expand Down
43 changes: 33 additions & 10 deletions scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@

"""
# This script would evaluate an N-gram language model trained with KenLM library (https://github.com/kpu/kenlm) in
# fusion with beam search decoders on top of a trained ASR model. NeMo's beam search decoders are capable of using the
# KenLM's N-gram models to find the best candidates. This script supports both character level and BPE level
# fusion with beam search decoders on top of a trained ASR model with CTC decoder. To evaluate a model with
# Transducer (RNN-T) decoder use another script 'scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer.py'.
# NeMo's beam search decoders are capable of using the KenLM's N-gram models
# to find the best candidates. This script supports both character level and BPE level
# encodings and models which is detected automatically from the type of the model.
# You may train the LM model with 'scripts/ngram_lm/train_kenlm.py'.
# You may train the LM model with 'scripts/asr_language_modeling/ngram_lm/train_kenlm.py'.

# Config Help

Expand All @@ -29,7 +31,7 @@
# USAGE

python eval_beamsearch_ngram.py nemo_model_file=<path to the .nemo file of the model> \
input_manifest=<path to the evaluation JSON manifest file \
input_manifest=<path to the evaluation JSON manifest file> \
kenlm_model_file=<path to the binary KenLM model> \
beam_width=[<list of the beam widths, separated with commas>] \
beam_alpha=[<list of the beam alphas, separated with commas>] \
Expand Down Expand Up @@ -70,6 +72,7 @@
from tqdm.auto import tqdm

import nemo.collections.asr as nemo_asr
from nemo.collections.asr.models import EncDecHybridRNNTCTCModel
from nemo.collections.asr.parts.submodules import ctc_beam_decoding
from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization, TextProcessingConfig
from nemo.core.config import hydra_runner
Expand Down Expand Up @@ -113,7 +116,7 @@ class EvalBeamSearchNGramConfig:

text_processing: Optional[TextProcessingConfig] = TextProcessingConfig(
punctuation_marks = ".,?",
separate_punctuation = True,
separate_punctuation = False,
do_lowercase = False,
rm_punctuation = False,
)
Expand Down Expand Up @@ -151,7 +154,10 @@ def beam_search_eval(
model.cfg.decoding.beam = cfg.decoding

# Update model's decoding strategy
model.change_decoding_strategy(model.cfg.decoding)
if isinstance(model, EncDecHybridRNNTCTCModel):
VahidooX marked this conversation as resolved.
Show resolved Hide resolved
model.change_decoding_strategy(model.cfg.decoding, decoder_type='ctc')
else:
model.change_decoding_strategy(model.cfg.decoding)
VahidooX marked this conversation as resolved.
Show resolved Hide resolved
logging.setLevel(level)

wer_dist_first = cer_dist_first = 0
Expand Down Expand Up @@ -199,6 +205,8 @@ def beam_search_eval(
pred_text = punctuation_capitalization.do_lowercase([pred_text])[0]
if cfg.text_processing.rm_punctuation:
pred_text = punctuation_capitalization.rm_punctuation([pred_text])[0]
if cfg.text_processing.separate_punctuation:
pred_text = punctuation_capitalization.separate_punctuation([pred_text])[0]
pred_split_w = pred_text.split()
wer_dist = editdistance.eval(target_split_w, pred_split_w)
pred_split_c = list(pred_text)
Expand Down Expand Up @@ -247,6 +255,7 @@ def beam_search_eval(

@hydra_runner(config_path=None, config_name='EvalBeamSearchNGramConfig', schema=EvalBeamSearchNGramConfig)
def main(cfg: EvalBeamSearchNGramConfig):
logging.warning("This file will be renamed to eval_beamsearch_ngram_ctc.py in the future NeMo (1.21) release.")
if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg) # type: EvalBeamSearchNGramConfig

Expand Down Expand Up @@ -279,12 +288,12 @@ def main(cfg: EvalBeamSearchNGramConfig):
audio_file_paths.append(str(audio_file.absolute()))

punctuation_capitalization = PunctuationCapitalization(cfg.text_processing.punctuation_marks)
if cfg.text_processing.separate_punctuation:
target_transcripts = punctuation_capitalization.separate_punctuation(target_transcripts)
if cfg.text_processing.do_lowercase:
target_transcripts = punctuation_capitalization.do_lowercase(target_transcripts)
if cfg.text_processing.rm_punctuation:
target_transcripts = punctuation_capitalization.rm_punctuation(target_transcripts)
if cfg.text_processing.separate_punctuation:
target_transcripts = punctuation_capitalization.separate_punctuation(target_transcripts)

if cfg.probs_cache_file and os.path.exists(cfg.probs_cache_file):
logging.info(f"Found a pickle file of probabilities at '{cfg.probs_cache_file}'.")
Expand Down Expand Up @@ -316,7 +325,15 @@ def default_autocast():

with autocast():
with torch.no_grad():
all_logits = asr_model.transcribe(audio_file_paths, batch_size=cfg.acoustic_batch_size, logprobs=True)
if isinstance(asr_model, EncDecHybridRNNTCTCModel):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The EncDecHybridRNNTCTCModel and regular ctc models should return the same number of outputs to be consistent.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally agree, but need to be discussed with Som

asr_model.cur_decoder = 'ctc'
all_logits, _ = asr_model.transcribe(
audio_file_paths, batch_size=cfg.acoustic_batch_size, logprobs=True
)
else:
all_logits = asr_model.transcribe(
audio_file_paths, batch_size=cfg.acoustic_batch_size, logprobs=True
)

all_probs = all_logits
if cfg.probs_cache_file:
Expand All @@ -331,11 +348,17 @@ def default_autocast():
for batch_idx, probs in enumerate(all_probs):
preds = np.argmax(probs, axis=1)
preds_tensor = torch.tensor(preds, device='cpu').unsqueeze(0)
pred_text = asr_model._wer.decoding.ctc_decoder_predictions_tensor(preds_tensor)[0][0]
if isinstance(asr_model, EncDecHybridRNNTCTCModel):
VahidooX marked this conversation as resolved.
Show resolved Hide resolved
pred_text = asr_model.ctc_decoding.ctc_decoder_predictions_tensor(preds_tensor)[0][0]
else:
pred_text = asr_model._wer.decoding.ctc_decoder_predictions_tensor(preds_tensor)[0][0]

if cfg.text_processing.do_lowercase:
pred_text = punctuation_capitalization.do_lowercase([pred_text])[0]
if cfg.text_processing.rm_punctuation:
pred_text = punctuation_capitalization.rm_punctuation([pred_text])[0]
if cfg.text_processing.separate_punctuation:
pred_text = punctuation_capitalization.separate_punctuation([pred_text])[0]

pred_split_w = pred_text.split()
target_split_w = target_transcripts[batch_idx].split()
Expand Down
Loading