@@ -106,6 +106,7 @@ def convert_num_to_words(_str: str, langid: str = "en") -> str:

 def cal_write_wer(
     pred_manifest: str = None,
+    gt_text_attr_name: str = "text",
     pred_text_attr_name: str = "pred_text",
     clean_groundtruth_text: bool = False,
     langid: str = 'en',
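The new `gt_text_attr_name` argument lets callers point `cal_write_wer` at whichever manifest field holds the reference transcript, while `"text"` stays the default so existing calls behave as before. A minimal usage sketch; the import path, the `"answer"` field name, and the names used when unpacking the return tuple are assumptions for illustration, only the keyword arguments come from the diff:

```python
# Hypothetical call: import path and the "answer" manifest field are assumed,
# not defined by this diff.
from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer

output_manifest, total_res, eval_metric = cal_write_wer(
    pred_manifest="predictions.json",  # JSONL manifest containing hypotheses
    gt_text_attr_name="answer",        # references stored under "answer" instead of "text"
    pred_text_attr_name="pred_text",
    clean_groundtruth_text=True,
    langid="en",
)
print(eval_metric, total_res)
```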
@@ -128,14 +129,17 @@ def cal_write_wer(
         for line in fp:
             sample = json.loads(line)

-            if 'text' not in sample:
-                logging.info(
-                    "ground-truth text is not present in manifest! Cannot calculate Word Error Rate. Returning!"
-                )
+            if gt_text_attr_name not in sample:
+                if "text" in sample:
+                    gt_text_attr_name = "text"
+                else:
+                    logging.info(
+                        f"ground-truth text attribute {gt_text_attr_name} is not present in manifest! Cannot calculate WER. Returning!"
+                    )
                 return None, None, eval_metric

-            hyp = sample[pred_text_attr_name]
-            ref = sample['text']
+            hyp = sample[pred_text_attr_name].strip()
+            ref = sample[gt_text_attr_name].strip()

             if clean_groundtruth_text:
                 ref = clean_label(ref, langid=langid)
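With the fallback above, a manifest whose references live under a custom key still evaluates, and rows that only carry the conventional `"text"` key are picked up automatically. A standalone sketch of that lookup; the manifest rows and the `"answer"` key are invented for illustration and this is not the PR's own code path:

```python
import json

# Invented JSONL rows; only "text" and "pred_text" are conventional manifest keys.
rows = [
    '{"audio_filepath": "a.wav", "answer": "hello world", "pred_text": "hello word"}',
    '{"audio_filepath": "b.wav", "text": "good morning", "pred_text": "good morning"}',
]

gt_text_attr_name = "answer"
for line in rows:
    sample = json.loads(line)
    if gt_text_attr_name not in sample:
        if "text" in sample:
            gt_text_attr_name = "text"  # fall back to the default manifest key
        else:
            print("no ground-truth field found, skipping row")
            continue
    ref = sample[gt_text_attr_name].strip()
    hyp = sample["pred_text"].strip()
    print(f"{gt_text_attr_name!r}: ref={ref!r} hyp={hyp!r}")
```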
@@ -211,13 +215,16 @@ def cal_write_text_metric(
             sample = json.loads(line)

             if gt_text_attr_name not in sample:
-                logging.info(
-                    f"ground-truth text attribute {pred_text_attr_name} is not present in manifest! Cannot calculate {metric}. Returning!"
-                )
+                if "text" in sample:
+                    gt_text_attr_name = "text"
+                else:
+                    logging.info(
+                        f"ground-truth text attribute {gt_text_attr_name} is not present in manifest! Cannot calculate {metric}. Returning!"
+                    )
                 return None, None, metric

-            hyp = sample[pred_text_attr_name]
-            ref = sample['text']
+            hyp = sample[pred_text_attr_name].strip()
+            ref = sample[gt_text_attr_name].strip()

             if ignore_punctuation:
                 ref = remove_punctuations(ref, punctuations=punctuations)
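The added `.strip()` calls matter mostly for character-level scoring: a stray trailing newline or space in a manifest field counts as an edit for CER even when the words match. A quick illustration using torchmetrics, which is an assumption for the demo; the diff does not show which metric implementation backs these functions:

```python
# Illustration only; CharErrorRate stands in for whatever metric_calculator
# these functions actually use.
from torchmetrics.text import CharErrorRate

cer = CharErrorRate()
print(cer("hello\n", "hello").item())          # trailing newline counts as an edit: 0.2
print(cer("hello\n".strip(), "hello").item())  # stripped, as in the updated code: 0.0
```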
@@ -227,13 +234,18 @@ def cal_write_text_metric(
                 ref = ref.lower()
                 hyp = hyp.lower()

-            score = metric_calculator(hyp, ref).item()
+            if metric == 'bleu':
+                score = metric_calculator([hyp], [[ref]]).item()
+            else:
+                score = metric_calculator(hyp, ref).item()
             sample[metric] = score  # evaluation metric, could be word error rate or character error rate

             samples.append(sample)
             hyps.append(hyp)
             refs.append(ref)

+    if metric == 'bleu':
+        refs = [[ref] for ref in refs]
     total_score = metric_calculator(hyps, refs).item()

     if not output_filename:
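BLEU has a different call shape from WER/CER: each hypothesis is a plain string, but the references are a list of acceptable translations per hypothesis, which is why the per-sample call wraps the pair as `[hyp]` / `[[ref]]` and the corpus-level call nests `refs` once before scoring. A sketch of that shape, assuming `metric_calculator` is a torchmetrics text metric such as `SacreBLEUScore`; the diff does not show how it is constructed:

```python
# Assumed backend for illustration: torchmetrics SacreBLEUScore.
from torchmetrics.text import SacreBLEUScore

bleu = SacreBLEUScore()

hyp = "the cat sat on the mat"
ref = "the cat sat on the mat"

# Per-sample: one hypothesis, one list of reference strings.
per_sample = bleu([hyp], [[ref]]).item()

# Corpus-level: hypotheses stay flat, references are nested once,
# mirroring `refs = [[ref] for ref in refs]` in the diff.
hyps = [hyp, "a dog barked loudly"]
refs = [[ref], ["the dog barked loudly"]]
corpus = bleu(hyps, refs).item()

print(per_sample, corpus)
```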