@@ -50,7 +50,7 @@ def websrc_process_results(doc, results):
50
50
"websrc_squad_f1" : websrc_ans ,
51
51
"submission" : {
52
52
websrc_ans ['question_id' ]: pred ,
53
- },
53
+ } if 'question_id' in websrc_ans else None
54
54
}
55
55
56
56
@@ -122,27 +122,39 @@ def _normalize_str(string):
122
122
# lower it
123
123
string = string .lower ()
124
124
125
- # strip non-alphanumeric characters
126
- string = re .sub (r"[^a-zA-Z0-9]" , "" , string )
127
-
128
125
# strip leading and trailing whitespaces
129
126
string = string .strip ()
130
127
131
128
return string
132
129
130
+ def _tokenize (text ):
131
+ # Regex pattern to match words and isolate punctuation
132
+ pattern = r'\w+|[^\w\s]'
133
+ tokens = re .findall (pattern , text )
134
+ return tokens
135
+
136
+ def _compute_f1 (sa , sb ):
137
+ sa = _normalize_str (sa )
138
+ sb = _normalize_str (sb )
139
+
140
+ sa = _tokenize (sa )
141
+ sb = _tokenize (sb )
142
+
143
+ sa = set (sa )
144
+ sb = set (sb )
145
+
146
+ if len (sa ) == 0 or len (sb ) == 0 :
147
+ return 0.0
148
+
149
+ comm = sa .intersection (sb )
150
+ prec = len (comm ) / len (sb )
151
+ rec = len (comm ) / len (sa )
152
+ f1 = 2 * prec * rec / (prec + rec ) if prec + rec > 0 else 0
153
+ return f1
154
+
133
155
judge_list = []
134
156
for sample in samples :
135
- gold_i = set (_normalize_str (sample ["answer" ]))
136
- pred_i = set (_normalize_str ( sample ["parsed_pred" ]))
137
- if len (pred_i ) == 0 :
138
- judge_list .append (0.0 )
139
- continue
140
-
141
- comm_i = gold_i .intersection (pred_i )
142
- prec_i = len (comm_i ) / len (pred_i )
143
- rec_i = len (comm_i ) / len (gold_i )
144
- f1_i = 2 * prec_i * rec_i / (prec_i + rec_i ) if prec_i + rec_i > 0 else 0
145
- judge_list .append (f1_i )
157
+ judge_list .append (_compute_f1 (sample ["answer" ], sample ["parsed_pred" ]))
146
158
147
159
f1 = np .mean (judge_list )
148
160
return judge_list , {"f1" : f1 }
0 commit comments