Fix AST eval #8112

Merged · 8 commits · Jan 10, 2024
4 changes: 4 additions & 0 deletions examples/asr/transcribe_speech.py
@@ -175,6 +175,9 @@ class TranscriptionConfig:
     # Set to False to return text instead of hypotheses from the transcribe function, so as to save memory
     return_hypotheses: bool = True
 
+    # key for groundtruth text in manifest
+    gt_text_attr_name: str = "text"
+
 
 @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
 def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis]]:
@@ -370,6 +373,7 @@ def autocast(dtype=None):
     if cfg.calculate_wer:
         output_manifest_w_wer, total_res, _ = cal_write_wer(
             pred_manifest=output_filename,
+            gt_text_attr_name=cfg.gt_text_attr_name,
             pred_text_attr_name=pred_text_attr_name,
             clean_groundtruth_text=cfg.clean_groundtruth_text,
             langid=cfg.langid,
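For context, here is a minimal sketch of the situation the new `gt_text_attr_name` option addresses: a manifest whose ground truth is stored under a non-default key. The `text_pnc` key and file names below are illustrative assumptions, not part of the PR.

```python
import json

# Hypothetical manifest line: the ground truth lives under "text_pnc"
# (e.g. text with punctuation and capitalization) instead of "text".
entry = {
    "audio_filepath": "sample.wav",
    "text_pnc": "Hello, world.",
    "pred_text": "hello world",
}

with open("manifest.json", "w") as f:
    f.write(json.dumps(entry) + "\n")

# transcribe_speech.py can now be pointed at that key via the new config
# field, e.g. gt_text_attr_name="text_pnc", so cal_write_wer reads the
# reference transcript from there instead of the default "text".
```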
36 changes: 24 additions & 12 deletions nemo/collections/asr/parts/utils/eval_utils.py
@@ -106,6 +106,7 @@

 def cal_write_wer(
     pred_manifest: str = None,
+    gt_text_attr_name: str = "text",
     pred_text_attr_name: str = "pred_text",
     clean_groundtruth_text: bool = False,
     langid: str = 'en',
@@ -128,14 +129,17 @@
     for line in fp:
         sample = json.loads(line)
 
-        if 'text' not in sample:
-            logging.info(
-                "ground-truth text is not present in manifest! Cannot calculate Word Error Rate. Returning!"
-            )
+        if gt_text_attr_name not in sample:
+            if "text" in sample:
+                gt_text_attr_name = "text"
+            else:
+                logging.info(
+                    f"ground-truth text attribute {gt_text_attr_name} is not present in manifest! Cannot calculate WER. Returning!"
+                )
                 return None, None, eval_metric
 
-        hyp = sample[pred_text_attr_name]
-        ref = sample['text']
+        hyp = sample[pred_text_attr_name].strip()
+        ref = sample[gt_text_attr_name].strip()
 
         if clean_groundtruth_text:
             ref = clean_label(ref, langid=langid)
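The fallback above can be summarized on its own. A minimal sketch follows; the helper name `resolve_gt_key` is hypothetical, and unlike the PR it returns the chosen key rather than mutating a variable in place.

```python
from typing import Optional

def resolve_gt_key(sample: dict, gt_text_attr_name: str) -> Optional[str]:
    """Pick the manifest key holding ground-truth text: prefer the
    configured name, fall back to the default 'text', else give up."""
    if gt_text_attr_name in sample:
        return gt_text_attr_name
    if "text" in sample:
        return "text"
    return None  # caller logs a message and aborts the metric computation

# The configured key wins when present; otherwise 'text' is used.
assert resolve_gt_key({"text_pnc": "Hi.", "text": "hi"}, "text_pnc") == "text_pnc"
assert resolve_gt_key({"text": "hi"}, "text_pnc") == "text"
assert resolve_gt_key({}, "text_pnc") is None
```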
@@ -211,13 +215,16 @@
         sample = json.loads(line)
 
         if gt_text_attr_name not in sample:
-            logging.info(
-                f"ground-truth text attribute {pred_text_attr_name} is not present in manifest! Cannot calculate {metric}. Returning!"
-            )
+            if "text" in sample:
+                gt_text_attr_name = "text"
+            else:
+                logging.info(
+                    f"ground-truth text attribute {gt_text_attr_name} is not present in manifest! Cannot calculate {metric}. Returning!"
+                )
                 return None, None, metric
 
-        hyp = sample[pred_text_attr_name]
-        ref = sample['text']
+        hyp = sample[pred_text_attr_name].strip()
+        ref = sample[gt_text_attr_name].strip()
 
         if ignore_punctuation:
             ref = remove_punctuations(ref, punctuations=punctuations)
@@ -227,13 +234,18 @@
             ref = ref.lower()
             hyp = hyp.lower()
 
-        score = metric_calculator(hyp, ref).item()
+        if metric == 'bleu':
+            score = metric_calculator([hyp], [[ref]]).item()
+        else:
+            score = metric_calculator(hyp, ref).item()
         sample[metric] = score  # evaluation metric, could be word error rate or character error rate
 
         samples.append(sample)
         hyps.append(hyp)
         refs.append(ref)
 
+    if metric == 'bleu':
+        refs = [[ref] for ref in refs]
     total_score = metric_calculator(hyps, refs).item()

     if not output_filename:
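The two call shapes for BLEU exist because corpus-level BLEU scorers take a list of hypotheses and, for each hypothesis, a list of acceptable references. A minimal sketch, assuming torchmetrics' `BLEUScore` (the diff does not show which calculator backs `metric_calculator`):

```python
from torchmetrics.text import BLEUScore

bleu = BLEUScore()

# Per-sample score: wrap the single hypothesis and its single reference.
score = bleu(["the cat sat on the mat"], [["the cat sat on the mat"]]).item()

# Corpus-level score: N hypotheses, each paired with a list of references,
# matching the refs = [[ref] for ref in refs] reshaping in the diff.
hyps = ["the cat sat on the mat", "a dog barked"]
refs = [["the cat sat on the mat"], ["the dog barked", "a dog barked"]]
total_score = bleu(hyps, refs).item()
print(score, total_score)
```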
2 changes: 2 additions & 0 deletions tools/asr_evaluator/asr_evaluator.py
@@ -66,6 +66,8 @@ def main(cfg):
     if cfg.analyst.metric_calculator.get("metric", "wer") == "wer":
         output_manifest_w_wer, total_res, eval_metric = cal_write_wer(
             pred_manifest=cfg.engine.output_filename,
+            gt_text_attr_name=cfg.analyst.metric_calculator.get("gt_text_attr_name", "text"),
+            pred_text_attr_name=cfg.analyst.metric_calculator.get("pred_text_attr_name", "pred_text"),
             clean_groundtruth_text=cfg.analyst.metric_calculator.clean_groundtruth_text,
             langid=cfg.analyst.metric_calculator.langid,
             use_cer=cfg.analyst.metric_calculator.use_cer,
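Reading the new keys with `.get()` and a default means existing evaluator configs keep working unchanged. A minimal sketch with OmegaConf (the config literal is illustrative):

```python
from omegaconf import OmegaConf

# An older config that predates the new keys.
cfg = OmegaConf.create({"analyst": {"metric_calculator": {"metric": "wer"}}})
mc = cfg.analyst.metric_calculator

# .get() falls back to the previous hard-coded defaults, so old configs
# behave exactly as before this PR.
gt_key = mc.get("gt_text_attr_name", "text")
pred_key = mc.get("pred_text_attr_name", "pred_text")
print(gt_key, pred_key)  # -> text pred_text
```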