
Ngram #6063

Merged
merged 11 commits on Mar 1, 2023
24 changes: 20 additions & 4 deletions examples/asr/speech_to_text_eval.py
@@ -66,6 +66,7 @@
from omegaconf import MISSING, OmegaConf, open_dict

from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.asr.parts.utils.transcribe_utils import preprocess, separate_punctuation
from nemo.core.config import hydra_runner
from nemo.utils import logging

@@ -79,6 +80,8 @@ class EvaluationConfig(transcribe_speech.TranscriptionConfig):
tolerance: Optional[float] = None

only_score_manifest: bool = False
do_lowercase: bool = False
rm_punctuation: bool = False


@hydra_runner(config_name="EvaluationConfig", schema=EvaluationConfig)
@@ -122,8 +125,12 @@ def main(cfg: EvaluationConfig):
invalid_manifest = True
break

ground_truth_text.append(data['text'])
predicted_text.append(data['pred_text'])
target_text = separate_punctuation(data['text'])
target_text = preprocess(target_text, cfg.do_lowercase, cfg.rm_punctuation)
ground_truth_text.append(target_text)

pred_text = preprocess(data['pred_text'], cfg.do_lowercase, cfg.rm_punctuation)
predicted_text.append(pred_text)

# Test for invalid manifest supplied
if invalid_manifest:
@@ -133,8 +140,15 @@
)

# Compute the WER
metric_name = 'CER' if cfg.use_cer else 'WER'
metric_value = word_error_rate(hypotheses=predicted_text, references=ground_truth_text, use_cer=cfg.use_cer)
cer = word_error_rate(hypotheses=predicted_text, references=ground_truth_text, use_cer=True)
wer = word_error_rate(hypotheses=predicted_text, references=ground_truth_text, use_cer=False)

if cfg.use_cer:
metric_name = 'CER'
metric_value = cer
else:
metric_name = 'WER'
metric_value = wer

if cfg.tolerance is not None:
if metric_value > cfg.tolerance:
@@ -144,6 +158,8 @@
else:
logging.info(f'Got {metric_name} of {metric_value}')

logging.info(f'WER/CER {round(100 * wer, 2)}%/{round(100 * cer, 2)}%')

# Inject the metric name and score into the config, and return the entire config
with open_dict(cfg):
cfg.metric_name = metric_name
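
For reference, the two new options are plain Hydra overrides for the evaluation script (for example: python examples/asr/speech_to_text_eval.py dataset_manifest=<manifest> do_lowercase=true rm_punctuation=true), and the script now computes both WER and CER regardless of use_cer, reporting the one selected by the flag. A minimal sketch of the metric calls, using toy strings that are not from this PR:

from nemo.collections.asr.metrics.wer import word_error_rate

# Toy hypothesis/reference pair, for illustration only.
references = ["the cat sat on the mat"]
hypotheses = ["the cat sat on a mat"]

# The evaluation script computes both metrics and logs them together.
wer = word_error_rate(hypotheses=hypotheses, references=references, use_cer=False)
cer = word_error_rate(hypotheses=hypotheses, references=references, use_cer=True)
print(f'WER/CER {round(100 * wer, 2)}%/{round(100 * cer, 2)}%')
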
39 changes: 38 additions & 1 deletion examples/asr/transcribe_speech.py
@@ -13,6 +13,7 @@
# limitations under the License.

import contextlib
import math
import os
from dataclasses import dataclass, is_dataclass
from typing import Optional
@@ -25,6 +26,7 @@
from nemo.collections.asr.metrics.wer import CTCDecodingConfig
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig
from nemo.collections.asr.parts.submodules import ctc_beam_decoding
from nemo.collections.asr.parts.utils.transcribe_utils import (
compute_output_filename,
prepare_audio_data,
@@ -36,7 +38,6 @@
from nemo.core.config import hydra_runner
from nemo.utils import logging


"""
Transcribe audio file on a single CPU/GPU. Useful for transcription of moderate amounts of audio data.

@@ -113,6 +114,22 @@ class TranscriptionConfig:
# General configs
output_filename: Optional[str] = None
batch_size: int = 32

beam_strategy: str = "beam" # greedy, beam, flashlight, pyctcdecode
lm_path: Optional[str] = None # Path to a KenLM model file
beam_width: int = 30 # for beam
beam_alpha: float = 0.5 # for beam
beam_beta: float = 0.3 # for beam
return_best_hypothesis: bool = True # for beam

# flashlight
lexicon_path: Optional[str] = None
beam_size_token: int = 16
beam_threshold: float = 20.0
unk_weight: float = -math.inf
sil_weight: float = 0.0
unit_lm: bool = True

num_workers: int = 0
append_pred: bool = False # Sets mode of work: if True, it will add a new field with the transcriptions.
pred_name_postfix: Optional[str] = None # If you need to use another model name rather than the standard one.
@@ -227,6 +244,26 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
raise ValueError("CTC models do not support `compute_langs` at the moment.")
cfg.ctc_decoding.compute_timestamps = cfg.compute_timestamps

cfg.ctc_decoding.strategy = cfg.beam_strategy # beam, flashlight, pyctcdecode
cfg.ctc_decoding.beam = ctc_beam_decoding.BeamCTCInferConfig(
beam_size=cfg.beam_width,
preserve_alignments=False,
compute_timestamps=cfg.compute_timestamps,
return_best_hypothesis=cfg.return_best_hypothesis,
beam_alpha=cfg.beam_alpha,
beam_beta=cfg.beam_beta,
kenlm_path=cfg.lm_path,
)
if cfg.beam_strategy == "flashlight":
cfg.ctc_decoding.beam.flashlight_cfg = ctc_beam_decoding.FlashlightConfig(
lexicon_path=cfg.lexicon_path,
beam_size_token=cfg.beam_size_token,
beam_threshold=cfg.beam_threshold,
unk_weight=cfg.unk_weight,
sil_weight=cfg.sil_weight,
unit_lm=cfg.unit_lm,
)

asr_model.change_decoding_strategy(cfg.ctc_decoding)

# prepare audio filepaths and decide whether it's partial audio
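
For reference, the new decoding options map onto the CTC beam decoding configs roughly as sketched below; the values and file paths are placeholders, and the field names follow the wiring in main() above. On the Hydra command line this corresponds to overrides such as beam_strategy=flashlight lm_path=... lexicon_path=...

from nemo.collections.asr.parts.submodules import ctc_beam_decoding

# Mirrors how main() builds cfg.ctc_decoding.beam; values are illustrative.
beam_cfg = ctc_beam_decoding.BeamCTCInferConfig(
    beam_size=30,                  # cfg.beam_width
    preserve_alignments=False,
    compute_timestamps=False,
    return_best_hypothesis=True,
    beam_alpha=0.5,                # LM weight
    beam_beta=0.3,                 # word insertion weight
    kenlm_path='/path/to/lm.bin',  # placeholder KenLM path
)

# Extra options consumed only when beam_strategy == "flashlight".
beam_cfg.flashlight_cfg = ctc_beam_decoding.FlashlightConfig(
    lexicon_path='/path/to/lexicon.txt',  # placeholder lexicon path
    beam_size_token=16,
    beam_threshold=20.0,
)
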
49 changes: 42 additions & 7 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
@@ -14,6 +14,7 @@
import glob
import json
import os
import re
from pathlib import Path
from typing import List, Optional, Tuple, Union

@@ -271,7 +272,7 @@ def normalize_timestamp_output(timestamps: dict):


def write_transcription(
transcriptions: List[rnnt_utils.Hypothesis],
transcriptions: Union[List[rnnt_utils.Hypothesis], List[List[rnnt_utils.Hypothesis]]],
cfg: DictConfig,
model_name: str,
filepaths: List[str] = None,
@@ -289,9 +290,22 @@
else:
pred_text_attr_name = 'pred_text'

if isinstance(transcriptions[0], rnnt_utils.Hypothesis): # List[rnnt_utils.Hypothesis]
best_hyps = transcriptions
assert cfg.return_best_hypothesis, "return_best_hypothesis=False requires N-best output (a list of hypotheses per sample)"
elif isinstance(transcriptions[0], list): # List[List[rnnt_utils.Hypothesis]]
best_hyps, beams = [], []
for hyps in transcriptions:
best_hyps.append(hyps[0])
if not cfg.return_best_hypothesis:
beam = []
for hyp in hyps:
beam.append((hyp.text, hyp.score))
beams.append(beam)

with open(cfg.output_filename, 'w', encoding='utf-8', newline='\n') as f:
if cfg.audio_dir is not None:
for idx, transcription in enumerate(transcriptions): # type: rnnt_utils.Hypothesis
for idx, transcription in enumerate(best_hyps): # type: rnnt_utils.Hypothesis
item = {'audio_filepath': filepaths[idx], pred_text_attr_name: transcription.text}

if compute_timestamps:
@@ -305,16 +319,17 @@
if compute_langs:
item['pred_lang'] = transcription.langs
item['pred_lang_chars'] = transcription.langs_chars

if not cfg.return_best_hypothesis:
item['beams'] = beams[idx]
f.write(json.dumps(item) + "\n")
else:
with open(cfg.dataset_manifest, 'r', encoding='utf_8') as fr:
for idx, line in enumerate(fr):
item = json.loads(line)
item[pred_text_attr_name] = transcriptions[idx].text
item[pred_text_attr_name] = best_hyps[idx].text

if compute_timestamps:
timestamps = transcriptions[idx].timestep
timestamps = best_hyps[idx].timestep
if timestamps is not None and isinstance(timestamps, dict):
timestamps.pop(
'timestep', None
@@ -324,9 +339,11 @@
item[f'timestamps_{key}'] = values

if compute_langs:
item['pred_lang'] = transcriptions[idx].langs
item['pred_lang_chars'] = transcriptions[idx].langs_chars
item['pred_lang'] = best_hyps[idx].langs
item['pred_lang_chars'] = best_hyps[idx].langs_chars

if not cfg.return_best_hypothesis:
item['beams'] = beams[idx]
f.write(json.dumps(item) + "\n")

return cfg.output_filename
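
For downstream use, each line of the output manifest now optionally carries the full beam in addition to the best hypothesis. A minimal sketch of reading it back, assuming the default pred_text field name and a placeholder file path:

import json

# Read the manifest written by write_transcription(); the path is hypothetical.
with open('transcriptions.json', 'r', encoding='utf-8') as f:
    for line in f:
        item = json.loads(line)
        print(item['pred_text'])  # text of the best hypothesis
        # 'beams' is present only when return_best_hypothesis=False; each entry is a
        # (text, score) pair, serialized by json.dumps as a two-element list.
        for text, score in item.get('beams', []):
            print(f'  {score:.3f}  {text}')
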
@@ -421,3 +438,21 @@ def transcribe_partial_audio(
asr_model.decoder.unfreeze()
logging.set_verbosity(logging_level)
return hypotheses


def separate_punctuation(line):
# Pad '.', ',' and '?' with spaces so that each mark becomes a standalone token.
punctuation_marks = '.,?'
regex_separate_punctuation = fr"([{punctuation_marks}])"
line = re.sub(regex_separate_punctuation, r' \1 ', line)
return line


def preprocess(line, do_lowercase, rm_punctuation):
# Normalize a transcript line before scoring: collapse newlines, optionally
# lowercase, and optionally replace punctuation marks with spaces.
line = line.replace("\n", " ").strip()
if do_lowercase:
line = line.lower()
if rm_punctuation:
line = line.replace(",", " ").strip()
line = line.replace(".", " ").strip()
line = line.replace("?", " ").strip()
return line
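
To make the intended normalization concrete, here is a short sketch of how the two helpers compose on a hypothetical reference string (illustrative only, not taken from the PR):

from nemo.collections.asr.parts.utils.transcribe_utils import preprocess, separate_punctuation

line = 'Hello, world. How are you?'

# Pads '.', ',' and '?' with spaces so each mark becomes a standalone token.
line = separate_punctuation(line)

# Lowercases the text and replaces the punctuation marks with spaces; the extra
# whitespace left behind is harmless for WER, which tokenizes on whitespace.
line = preprocess(line, do_lowercase=True, rm_punctuation=True)

print(line)  # roughly: 'hello    world    how are you'
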
@@ -71,6 +71,7 @@

import nemo.collections.asr as nemo_asr
from nemo.collections.asr.parts.submodules import ctc_beam_decoding
from nemo.collections.asr.parts.utils.transcribe_utils import preprocess, separate_punctuation
from nemo.core.config import hydra_runner
from nemo.utils import logging

@@ -109,7 +110,9 @@ class EvalBeamSearchNGramConfig:

decoding_strategy: str = "beam" # Supports only beam for now
decoding: ctc_beam_decoding.BeamCTCInferConfig = ctc_beam_decoding.BeamCTCInferConfig(beam_size=128)


do_lowercase: bool = False
rm_punctuation: bool = False

# fmt: on

@@ -181,13 +184,16 @@ def beam_search_eval(

for beams_idx, beams in enumerate(beams_batch):
target = target_transcripts[sample_idx + beams_idx]
target = separate_punctuation(target)
target = preprocess(target, cfg.do_lowercase, cfg.rm_punctuation)
target_split_w = target.split()
target_split_c = list(target)
words_count += len(target_split_w)
chars_count += len(target_split_c)
wer_dist_min = cer_dist_min = 10000
for candidate_idx, candidate in enumerate(beams): # type: (int, ctc_beam_decoding.rnnt_utils.Hypothesis)
pred_text = candidate.text
pred_text = preprocess(pred_text, cfg.do_lowercase, cfg.rm_punctuation)
pred_split_w = pred_text.split()
wer_dist = editdistance.eval(target_split_w, pred_split_w)
pred_split_c = list(pred_text)
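
The same normalization is applied to both the targets and the beam candidates before the edit distances are accumulated. A toy sketch of the per-sample bookkeeping in the loop above, with made-up strings:

import editdistance

# Pick the candidate in the beam with the smallest word-level edit distance to the target.
target = 'the cat sat on the mat'
candidates = ['the cat sat on a mat', 'the cat sat on the mat', 'a cat sat on the mat']

target_words = target.split()
wer_dist_min = min(editdistance.eval(target_words, c.split()) for c in candidates)
print(f'best-in-beam WER for this sample: {wer_dist_min / len(target_words):.2%}')
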