Commit 955bd06

Init webSRC

1 parent: c4e9dd9

File tree

4 files changed (+163, -0 lines)

Diff for: lmms_eval/tasks/websrc/README.md

+1
@@ -0,0 +1 @@
# WebSRC

Diff for: lmms_eval/tasks/websrc/utils.py

+140
@@ -0,0 +1,140 @@
from collections import defaultdict
import re
import ast
import base64
import io
import random
import numpy as np
import os
import json
import logging
from PIL import Image

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

lmms_logger = logging.getLogger("lmms-eval")

OPEN_ENDED_PROMPT = "Answer the question using a single word or phrase."


def construct_prompt(doc):
    question = doc["question"]
    # question = f"{question}\n{OPEN_ENDED_PROMPT}"
    question = f"{OPEN_ENDED_PROMPT}\n{question}"
    return question


def websrc_doc_to_text(doc):
    question = construct_prompt(doc)
    return question


def websrc_doc_to_visual(doc):
    img_bs64 = doc["image"]
    img = Image.open(io.BytesIO(base64.b64decode(img_bs64)))
    del doc["image"]
    return [img]


def websrc_process_results(doc, results):
    pred = results[0]
    parsed_pred = pred
    id = doc["page_id"]
    websrc_ans = {"id": id, "domain": doc["domain"], "answer": doc["answer"], "parsed_pred": parsed_pred}
    return {
        "websrc_squad_f1": websrc_ans,
        "submission": {
            id: pred,
        },
    }


def websrc_test_aggregate_results_for_submission(results, args):
    path = generate_submission_file("websrc_test_for_submission.json", args)
    with open(path, "w") as f:
        json.dump(results, f)
    lmms_logger.info(f"Results saved to {path}.")


def websrc_aggregate_results(results):
    evaluation_result = {}

    # Group results by domain
    subset_to_eval_samples = defaultdict(list)
    for result in results:
        subset_to_eval_samples[result["domain"]].append(result)

    # Evaluate each domain
    for subset, sub_eval_samples in subset_to_eval_samples.items():
        judge_dict, metric_dict = evaluate_websrc(sub_eval_samples)
        metric_dict.update({"num_example": len(sub_eval_samples)})
        evaluation_result[subset] = metric_dict

    # Aggregate results over all domains
    printable_results = {}
    for domain in DOMAINS:
        if domain not in evaluation_result:
            continue
        printable_results[domain] = {
            "num": int(evaluation_result[domain]["num_example"]),
            "f1": round(evaluation_result[domain]["f1"], 3),
        }
    all_ins_f1 = np.sum([cat_results["f1"] * cat_results["num_example"] for cat_results in evaluation_result.values()]) / sum(
        [cat_results["num_example"] for cat_results in evaluation_result.values()]
    )
    printable_results["Overall"] = {
        "num": sum([cat_results["num_example"] for cat_results in evaluation_result.values()]),
        "f1": round(all_ins_f1, 3),
    }
    print(printable_results)
    return printable_results["Overall"]["f1"]


##################
# Helper functions adapted from the official MMMU repo.
##################
DOMAINS = [
    "auto",
    "book",
    "camera",
    "game",
    "jobs",
    "movie",
    "phone",
    "restaurant",
    "sports",
    "university",
    "hotel",
]


def evaluate_websrc(samples):
    def _normalize_str(string):
        # lowercase it
        string = string.lower()

        # strip non-alphanumeric characters
        string = re.sub(r"[^a-zA-Z0-9]", "", string)

        # strip leading and trailing whitespace
        string = string.strip()

        return string

    judge_list = []
    for sample in samples:
        # Character-level sets of the normalized gold answer and prediction
        gold_i = set(_normalize_str(sample["answer"]))
        pred_i = set(_normalize_str(sample["parsed_pred"]))
        if len(pred_i) == 0:
            judge_list.append(0.0)
            continue

        comm_i = gold_i.intersection(pred_i)
        prec_i = len(comm_i) / len(pred_i)
        rec_i = len(comm_i) / len(gold_i)
        f1_i = 2 * prec_i * rec_i / (prec_i + rec_i) if prec_i + rec_i > 0 else 0
        judge_list.append(f1_i)

    f1 = np.mean(judge_list)
    return judge_list, {"f1": f1}
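
As a quick sanity check on the character-set F1 helper above, the sketch below (not part of the commit) scores two hand-made samples in the same dict format that websrc_process_results emits. It assumes lmms-eval is installed and that the script is run from inside lmms_eval/tasks/websrc so that "import utils" resolves to the file above.

# A minimal sanity check for the character-set F1 helper above; the two
# samples are made up and mirror the dicts produced by websrc_process_results.
from utils import evaluate_websrc

samples = [
    # "$129.99" and "129.99" both normalize to the character set {1, 2, 9},
    # so this pair scores a perfect 1.0.
    {"id": "demo-1", "domain": "phone", "answer": "$129.99", "parsed_pred": "129.99"},
    # "yes" and "no" share no characters, so this pair scores 0.0.
    {"id": "demo-2", "domain": "phone", "answer": "Yes", "parsed_pred": "No"},
]

judge_list, metrics = evaluate_websrc(samples)
print(judge_list)     # [1.0, 0.0]
print(metrics["f1"])  # 0.5

Because the comparison is over character sets of the normalized strings, this is a SQuAD-style F1 at the character level: a prediction earns partial credit for any digits or letters it shares with the gold answer.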

Diff for: lmms_eval/tasks/websrc/websrc.yaml

+3
@@ -0,0 +1,3 @@
group: websrc
task:
  - websrc_val

Diff for: lmms_eval/tasks/websrc/websrc_val.yaml

+19
@@ -0,0 +1,19 @@
dataset_path: rootsautomation/websrc
task: "websrc_val"
test_split: dev
output_type: generate_until
doc_to_visual: !function utils.websrc_doc_to_visual
doc_to_text: !function utils.websrc_doc_to_text
doc_to_target: "answer"
# The return value of process_results will be used by metrics
process_results: !function utils.websrc_process_results
# Note that the metric name can be either a registered metric function (such as the case for GQA) or a key name returned by process_results
generation_kwargs:
  max_new_tokens: 16
  image_aspect_ratio: pad
metric_list:
  - metric: websrc_squad_f1
    aggregation: !function utils.websrc_aggregate_results
    higher_is_better: true
metadata:
  - version: 0.0
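
To make the wiring concrete, here is a rough sketch of what the harness does per document when running websrc_val with this config: decode the screenshot, build the prompt, then post-process one model response. It is not part of the commit; the doc is synthetic, and it assumes lmms-eval is installed and the script is run from inside lmms_eval/tasks/websrc so that "import utils" picks up the file above.

import base64
import io

from PIL import Image

import utils  # the utils.py added in this commit

# Fake a dataset row: a tiny screenshot encoded as base64, plus the QA fields
# that websrc_doc_to_visual / websrc_process_results expect.
buf = io.BytesIO()
Image.new("RGB", (64, 64), "white").save(buf, format="PNG")
doc = {
    "page_id": "demo-0001",
    "domain": "hotel",
    "question": "What is the nightly rate?",
    "answer": "$120",
    "image": base64.b64encode(buf.getvalue()).decode("utf-8"),
}

visuals = utils.websrc_doc_to_visual(doc)  # returns [PIL.Image]; note it also deletes doc["image"]
prompt = utils.websrc_doc_to_text(doc)     # prepends the open-ended prompt to the question
result = utils.websrc_process_results(doc, ["$120"])  # second arg: list of model generations
print(prompt)
print(result["websrc_squad_f1"])  # this dict is what websrc_aggregate_results consumes

At the end of a run, websrc_aggregate_results groups these per-document dicts by domain and reports an example-weighted overall F1, which is the single number logged for websrc_squad_f1.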
