Commit a5c1869

Merge pull request #137 from shuyansy/main

add MLVU task

2 parents 2ebec77 + 557083a

File tree: 3 files changed (+145, -0 lines)
Binary file not shown (3.13 KB).

Diff for: lmms_eval/tasks/mlvu/mlvu.yaml

+21 lines (new file)
dataset_path: sy1998/temp
dataset_kwargs:
  token: True
  cache_dir: mlvu
  video: True
task: mlvu
test_split: test
output_type: generate_until
doc_to_visual: !function utils.mlvu_doc_to_visual
doc_to_text: !function utils.mlvu_doc_to_text
doc_to_target: "answer"
# The return value of process_results will be used by metrics
process_results: !function utils.mlvu_process_results
# Note that the metric name can be either a registered metric function (such as the case for GQA) or a key name returned by process_results
metric_list:
  - metric: mlvu_percetion_score
    aggregation: !function utils.mlvu_aggregate_results
    higher_is_better: true

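The !function entries above point at callables defined in utils.py below, which plain yaml.safe_load cannot resolve. The following is a minimal sketch, not part of this commit, for inspecting the config outside the harness; it reuses the same line-filtering workaround that utils.py applies, and the relative path assumes a local checkout of lmms-eval.

# Sketch only: read the task config without resolving the custom !function tags.
from pathlib import Path

import yaml

config_path = Path("lmms_eval/tasks/mlvu/mlvu.yaml")  # assumed local checkout
with open(config_path, "r") as f:
    # Drop lines carrying !function tags, as utils.py does, so safe_load succeeds.
    lines = [line for line in f if "!function" not in line]

config = yaml.safe_load("".join(lines))
print(config["task"])                                # mlvu
print(config["dataset_kwargs"]["cache_dir"])         # mlvu
print([m["metric"] for m in config["metric_list"]])  # ['mlvu_percetion_score']
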
Diff for: lmms_eval/tasks/mlvu/utils.py

+124 lines (new file)
from collections import defaultdict
import os
import datetime
import json
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
from pathlib import Path
import yaml
import sys
from typing import List, Dict, Optional, Union
import re
import cv2
import numpy as np
from loguru import logger as eval_logger

TASK_TYPES = [
    "TR",
    "AR",
    "VS",
    "NQA",
    "ER",
    "PQA",
    "SSC",
    "AO",
    "AC",
]


hf_home = os.getenv("HF_HOME", "./~/.cache/huggingface")
base_cache_dir = os.path.expanduser(hf_home)

with open(Path(__file__).parent / "mlvu.yaml", "r") as f:
    raw_data = f.readlines()
    safe_data = []
    for i, line in enumerate(raw_data):
        # remove function definition since yaml load cannot handle it
        if "!function" not in line:
            safe_data.append(line)
    cache_name = yaml.safe_load("".join(safe_data))["dataset_kwargs"]["cache_dir"]


def mlvu_doc_to_visual(doc):
    cache_dir = os.path.join(base_cache_dir, cache_name)
    video_path = doc["video_name"]
    video_path = os.path.join(cache_dir, video_path)
    if not os.path.exists(video_path):
        sys.exit(f"video path:{video_path} does not exist, please check")
    return [video_path]


def mlvu_doc_to_text(doc, model_specific_prompt_kwargs=None):
    # option_prompt = "Carefully watch this video and pay attention to every detail. Based on your observations, select the best option that accurately addresses the question."
    option_prompt = ""
    question = doc["question"] + "\nOnly give the best option.\n"
    full_prompt = option_prompt + "\n" + question + "\n" + "Best option: ("
    return full_prompt


def extract_characters_regex(s):
    # The prompt ends with "Best option: (", so the predicted letter is taken as
    # the character immediately before the first ")" in the model output.
    s = s.strip()
    if ")" in s:
        index = s.index(")")
        pred = s[index - 1 : index]
        return pred
    else:
        return s


def mlvu_process_results(doc, results):
    """
    Args:
        doc: an instance of the eval dataset
        results: [pred]
    Returns:
        a dictionary with key: metric name (in this case mlvu_percetion_score), value: metric value
    """
    pred = results[0]
    pred_ans = extract_characters_regex(pred)

    task_type = doc["task_type"]
    data_dict = {"question_id": doc["question"], "task_type": task_type, "pred_answer": pred_ans, "answer": doc["answer"]}

    return {f"mlvu_percetion_score": data_dict}


def mlvu_aggregate_results(results):
    """
    Args:
        results: a list of values returned by process_results
    Returns:
        A score
    """
    category2score = {}
    for task_type in TASK_TYPES:
        category2score[task_type] = {"correct": 0, "answered": 0}

    for result in results:
        task_type = result["task_type"]
        category2score[task_type]["answered"] += 1
        category2score[task_type]["correct"] += result["pred_answer"] == result["answer"]

    # Log per-category accuracy, then return the overall accuracy in percent.
    for task_cate in TASK_TYPES:
        total_correct = 0
        total_answered = 0
        for k, v in category2score.items():
            if task_cate in k:
                total_correct += v["correct"]
                total_answered += v["answered"]
        eval_logger.info(f"Evaluation on Task Categories: {task_cate}: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%")

    total_correct = 0
    total_answered = 0
    for k, v in category2score.items():
        total_correct += v["correct"]
        total_answered += v["answered"]
    eval_logger.info(f"Overall Performance: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%")

    return 100 * total_correct / total_answered if total_answered > 0 else 0

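For reference, here is a small end-to-end sketch, not part of this commit, of how the two hooks fit together: mlvu_process_results packages each prediction with its task_type, and mlvu_aggregate_results groups those records by category and reports accuracy. The two documents and predictions below are made up for illustration, and the import assumes an lmms-eval installation that includes this commit.

# Illustrative only: dummy docs and predictions, not real MLVU data.
from lmms_eval.tasks.mlvu.utils import mlvu_aggregate_results, mlvu_process_results

dummy_docs = [
    {"question": "What happens first?", "task_type": "AO", "answer": "A", "video_name": "demo1.mp4"},
    {"question": "How many cuts are there?", "task_type": "AC", "answer": "C", "video_name": "demo2.mp4"},
]
dummy_preds = ["(A) the door opens", "(B) four"]  # raw model outputs; the letter before ")" is parsed

per_doc = [
    mlvu_process_results(doc, [pred])["mlvu_percetion_score"]
    for doc, pred in zip(dummy_docs, dummy_preds)
]
overall = mlvu_aggregate_results(per_doc)  # logs per-category accuracy via loguru
print(overall)  # 50.0 here: the AO answer matches, the AC answer does not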