
Commit 941282d

stevehuang52 and fayejf authored and committed
Fix AST eval (NVIDIA#8112)
* add text metrics to asr eval
* temporary fix for EncDecTransfModelBPE
* fix bleu eval
* fix typo

Signed-off-by: stevehuang52 <[email protected]>
Co-authored-by: fayejf <[email protected]>
Signed-off-by: Sasha Meister <[email protected]>
1 parent c79f839 commit 941282d

File tree: 3 files changed (+30, -12 lines)


examples/asr/transcribe_speech.py

+4
@@ -175,6 +175,9 @@ class TranscriptionConfig:
     # Set to False to return text instead of hypotheses from the transcribe function, so as to save memory
     return_hypotheses: bool = True
 
+    # key for groundtruth text in manifest
+    gt_text_attr_name: str = "text"
+
 
 @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
 def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis]]:
@@ -370,6 +373,7 @@ def autocast(dtype=None):
     if cfg.calculate_wer:
         output_manifest_w_wer, total_res, _ = cal_write_wer(
             pred_manifest=output_filename,
+            gt_text_attr_name=cfg.gt_text_attr_name,
             pred_text_attr_name=pred_text_attr_name,
             clean_groundtruth_text=cfg.clean_groundtruth_text,
             langid=cfg.langid,
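
The new gt_text_attr_name option lets the WER/CER step read the reference from a manifest key other than the default "text", which is what AST manifests need when the ground-truth translation is stored under a different field. Below is a minimal sketch of the underlying call that transcribe_speech.py reaches when calculate_wer is set; the file name "preds_manifest.json" and the "answer" key are invented for illustration, not taken from the commit.

import json

from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer

# Manifest rows whose ground truth lives under a non-default key ("answer" is hypothetical).
rows = [
    {"pred_text": "hello world", "answer": "hello world"},
    {"pred_text": "good morning", "answer": "good mourning"},
]
with open("preds_manifest.json", "w") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")

# Point the scorer at the custom ground-truth key instead of the default "text".
manifest_with_wer, total_res, metric = cal_write_wer(
    pred_manifest="preds_manifest.json",
    gt_text_attr_name="answer",
    pred_text_attr_name="pred_text",
)
print(metric, total_res)

From the transcribe_speech.py entry point, the same path is exercised by overriding gt_text_attr_name in TranscriptionConfig while calculate_wer is enabled.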

nemo/collections/asr/parts/utils/eval_utils.py

+24 -12
@@ -106,6 +106,7 @@ def convert_num_to_words(_str: str, langid: str = "en") -> str:
 
 def cal_write_wer(
     pred_manifest: str = None,
+    gt_text_attr_name: str = "text",
     pred_text_attr_name: str = "pred_text",
     clean_groundtruth_text: bool = False,
     langid: str = 'en',
@@ -128,14 +129,17 @@
         for line in fp:
             sample = json.loads(line)
 
-            if 'text' not in sample:
-                logging.info(
-                    "ground-truth text is not present in manifest! Cannot calculate Word Error Rate. Returning!"
-                )
+            if gt_text_attr_name not in sample:
+                if "text" in sample:
+                    gt_text_attr_name = "text"
+                else:
+                    logging.info(
+                        f"ground-truth text attribute {gt_text_attr_name} is not present in manifest! Cannot calculate WER. Returning!"
+                    )
                 return None, None, eval_metric
 
-            hyp = sample[pred_text_attr_name]
-            ref = sample['text']
+            hyp = sample[pred_text_attr_name].strip()
+            ref = sample[gt_text_attr_name].strip()
 
             if clean_groundtruth_text:
                 ref = clean_label(ref, langid=langid)
@@ -211,13 +215,16 @@ def cal_write_text_metric(
             sample = json.loads(line)
 
             if gt_text_attr_name not in sample:
-                logging.info(
-                    f"ground-truth text attribute {pred_text_attr_name} is not present in manifest! Cannot calculate {metric}. Returning!"
-                )
+                if "text" in sample:
+                    gt_text_attr_name = "text"
+                else:
+                    logging.info(
+                        f"ground-truth text attribute {gt_text_attr_name} is not present in manifest! Cannot calculate {metric}. Returning!"
+                    )
                 return None, None, metric
 
-            hyp = sample[pred_text_attr_name]
-            ref = sample['text']
+            hyp = sample[pred_text_attr_name].strip()
+            ref = sample[gt_text_attr_name].strip()
 
             if ignore_punctuation:
                 ref = remove_punctuations(ref, punctuations=punctuations)
@@ -227,13 +234,18 @@
                 ref = ref.lower()
                 hyp = hyp.lower()
 
-            score = metric_calculator(hyp, ref).item()
+            if metric == 'bleu':
+                score = metric_calculator([hyp], [[ref]]).item()
+            else:
+                score = metric_calculator(hyp, ref).item()
             sample[metric] = score  # evaluatin metric, could be word error rate of character error rate
 
             samples.append(sample)
             hyps.append(hyp)
             refs.append(ref)
 
+    if metric == 'bleu':
+        refs = [[ref] for ref in refs]
     total_score = metric_calculator(hyps, refs).item()
 
     if not output_filename:
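
The BLEU branch exists because BLEU-style metrics score batched, multi-reference inputs rather than a single hypothesis/reference string pair. A short sketch of the shape difference, assuming the 'bleu' calculator behaves like torchmetrics' SacreBLEUScore (the exact metric class is not shown in this diff, and the sample sentences are invented):

from torchmetrics.text import SacreBLEUScore, WordErrorRate

hyp = "das haus ist gross"
ref = "das haus ist gross"

# WER/CER-style metrics accept one hypothesis string against one reference string.
wer = WordErrorRate()(hyp, ref).item()

# SacreBLEUScore expects a sequence of hypotheses and, for each hypothesis,
# a sequence of candidate references, hence the per-sample call shape [hyp], [[ref]].
bleu = SacreBLEUScore()
per_sample = bleu([hyp], [[ref]]).item()

# The corpus-level call needs the same wrapping, which is why refs is rebuilt
# as a list of single-element lists before the final metric_calculator(hyps, refs).
hyps = ["das haus ist gross", "ein hund lief"]
refs = [["das haus ist gross"], ["ein hund rannte"]]
corpus = bleu(hyps, refs).item()
print(wer, per_sample, corpus)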

tools/asr_evaluator/asr_evaluator.py

+2
@@ -66,6 +66,8 @@ def main(cfg):
     if cfg.analyst.metric_calculator.get("metric", "wer") == "wer":
         output_manifest_w_wer, total_res, eval_metric = cal_write_wer(
             pred_manifest=cfg.engine.output_filename,
+            gt_text_attr_name=cfg.analyst.metric_calculator.get("gt_text_attr_name", "text"),
+            pred_text_attr_name=cfg.analyst.metric_calculator.get("pred_text_attr_name", "pred_text"),
             clean_groundtruth_text=cfg.analyst.metric_calculator.clean_groundtruth_text,
             langid=cfg.analyst.metric_calculator.langid,
             use_cer=cfg.analyst.metric_calculator.use_cer,
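
In the standalone evaluator the two attribute names are read from the analyst.metric_calculator section with .get(), so both keys are optional and fall back to "text" / "pred_text". A minimal, illustrative OmegaConf fragment showing where they would sit; only the fields touched by this commit are included, and the manifest path is a placeholder:

from omegaconf import OmegaConf

# Illustrative config fragment only; a real asr_evaluator config has more fields.
cfg = OmegaConf.create(
    {
        "engine": {"output_filename": "preds_manifest.json"},
        "analyst": {
            "metric_calculator": {
                "metric": "wer",
                "gt_text_attr_name": "text",         # manifest key holding the ground truth
                "pred_text_attr_name": "pred_text",  # manifest key holding the predictions
                "clean_groundtruth_text": False,
                "langid": "en",
                "use_cer": False,
            }
        },
    }
)

# Optional keys are read with .get(), so omitting them falls back to the defaults.
print(cfg.analyst.metric_calculator.get("gt_text_attr_name", "text"))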
