Skip to content

Commit 626e8a9

Browse files
committed
Bugfix: WebSRC should be token-level F1 NOT character-level
1 parent eef3aeb commit 626e8a9

File tree

1 file changed

+27
-15
lines changed

1 file changed

+27
-15
lines changed

Diff for: lmms_eval/tasks/websrc/utils.py

+27-15
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def websrc_process_results(doc, results):
5050
"websrc_squad_f1": websrc_ans,
5151
"submission": {
5252
websrc_ans['question_id']: pred,
53-
},
53+
} if 'question_id' in websrc_ans else None
5454
}
5555

5656

@@ -122,27 +122,39 @@ def _normalize_str(string):
122122
# lower it
123123
string = string.lower()
124124

125-
# strip non-alphanumeric characters
126-
string = re.sub(r"[^a-zA-Z0-9]", "", string)
127-
128125
# strip leading and trailing whitespaces
129126
string = string.strip()
130127

131128
return string
132129

130+
def _tokenize(text):
131+
# Regex pattern to match words and isolate punctuation
132+
pattern = r'\w+|[^\w\s]'
133+
tokens = re.findall(pattern, text)
134+
return tokens
135+
136+
def _compute_f1(sa, sb):
137+
sa = _normalize_str(sa)
138+
sb = _normalize_str(sb)
139+
140+
sa = _tokenize(sa)
141+
sb = _tokenize(sb)
142+
143+
sa = set(sa)
144+
sb = set(sb)
145+
146+
if len(sa) == 0 or len(sb) == 0:
147+
return 0.0
148+
149+
comm = sa.intersection(sb)
150+
prec = len(comm) / len(sb)
151+
rec = len(comm) / len(sa)
152+
f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0
153+
return f1
154+
133155
judge_list = []
134156
for sample in samples:
135-
gold_i = set(_normalize_str(sample["answer"]))
136-
pred_i = set(_normalize_str( sample["parsed_pred"]))
137-
if len(pred_i) == 0:
138-
judge_list.append(0.0)
139-
continue
140-
141-
comm_i = gold_i.intersection(pred_i)
142-
prec_i = len(comm_i) / len(pred_i)
143-
rec_i = len(comm_i) / len(gold_i)
144-
f1_i = 2 * prec_i * rec_i / (prec_i + rec_i) if prec_i + rec_i > 0 else 0
145-
judge_list.append(f1_i)
157+
judge_list.append(_compute_f1(sample["answer"], sample["parsed_pred"]))
146158

147159
f1 = np.mean(judge_list)
148160
return judge_list, {"f1": f1}

0 commit comments

Comments
 (0)