Skip to content

Commit 743673a

Browse files
committed
2 parents c1a5472 + 95df9fe commit 743673a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+2214
-129
lines changed

Diff for: README.md

+28-12
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,21 @@ We also provide the raw data exported from Weights & Biases for the detailed res
201201
- OKVQA Validation 2014 (ok_vqa_val2014)
202202
- POPE (pope)
203203
- RefCOCO (refcoco)
204-
- refcoco_seg_test
205-
- refcoco_seg_val
206-
- refcoco_seg_testA
207-
- refcoco_seg_testB
208-
- refcoco_bbox_test
209-
- refcoco_bbox_val
210-
- refcoco_bbox_testA
211-
- refcoco_bbox_testB
204+
- refcoco_seg
205+
- refcoco_seg_test
206+
- refcoco_seg_val
207+
- refcoco_seg_testA
208+
- refcoco_seg_testB
209+
- refcoco_bbox
210+
- refcoco_bbox_test
211+
- refcoco_bbox_val
212+
- refcoco_bbox_testA
213+
- refcoco_bbox_testB
214+
- refcoco_bbox_rec
215+
- refcoco_bbox_rec_test
216+
- refcoco_bbox_rec_val
217+
- refcoco_bbox_rec_testA
218+
- refcoco_bbox_rec_testB
212219
- RefCOCO+ (refcoco+)
213220
- refcoco+_seg
214221
- refcoco+_seg_val
@@ -218,11 +225,20 @@ We also provide the raw data exported from Weights & Biases for the detailed res
218225
- refcoco+_bbox_val
219226
- refcoco+_bbox_testA
220227
- refcoco+_bbox_testB
228+
- refcoco+_bbox_rec
229+
- refcoco+_bbox_rec_val
230+
- refcoco+_bbox_rec_testA
231+
- refcoco+_bbox_rec_testB
221232
- RefCOCOg (refcocog)
222-
- refcocog_seg_test
223-
- refcocog_seg_val
224-
- refcocog_bbox_test
225-
- refcocog_bbox_val
233+
- refcocog_seg
234+
- refcocog_seg_test
235+
- refcocog_seg_val
236+
- refcocog_bbox
237+
- refcocog_bbox_test
238+
- refcocog_bbox_val
239+
- refcocog_bbox_rec
240+
- refcocog_bbox_rec_test
241+
- refcocog_bbox_rec_val
226242
- ScienceQA (scienceqa_full)
227243
- ScienceQA Full (scienceqa)
228244
- ScienceQA IMG (scienceqa_img)

Diff for: lmms_eval/api/task.py

+27-12
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,22 @@ def _prepare_metric_and_aggregation(self):
678678

679679
@retry(stop=stop_after_attempt(5), wait=wait_fixed(2))
680680
def download(self, dataset_kwargs=None) -> None:
681+
# If the dataset is a video dataset,
682+
# Recursively search whether there is a zip and unzip it to the huggingface home
683+
if dataset_kwargs is not None and "video" in dataset_kwargs and dataset_kwargs["video"]:
684+
hf_home = os.environ["HF_HOME"]
685+
cache_dir = dataset_kwargs["cache_dir"]
686+
dataset_kwargs.pop("cache_dir")
687+
cache_dir = os.path.join(hf_home, cache_dir)
688+
cache_path = snapshot_download(repo_id=self.DATASET_PATH, repo_type="dataset")
689+
zip_files = glob(os.path.join(cache_path, "**/*.zip"), recursive=True)
690+
if not os.path.exists(cache_dir):
691+
for zip_file in zip_files:
692+
shutil.unpack_archive(zip_file, cache_dir)
693+
builder_script = dataset_kwargs["builder_script"]
694+
self.DATASET_PATH = os.path.join(cache_path, builder_script)
695+
dataset_kwargs.pop("video")
696+
dataset_kwargs.pop("builder_script")
681697
download_config = DownloadConfig()
682698
download_config.max_retries = dataset_kwargs.get("max_retries", 3) if dataset_kwargs is not None else 3
683699
download_config.num_proc = dataset_kwargs.get("num_proc", 8) if dataset_kwargs is not None else 8
@@ -687,12 +703,15 @@ def download(self, dataset_kwargs=None) -> None:
687703
download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
688704
**dataset_kwargs if dataset_kwargs is not None else {},
689705
)
690-
self.dataset_no_image = datasets.load_dataset(
691-
path=self.DATASET_PATH,
692-
name=self.DATASET_NAME,
693-
download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
694-
**dataset_kwargs if dataset_kwargs is not None else {},
695-
)
706+
if self.config.process_docs is not None:
707+
for split in self.dataset:
708+
if split in [
709+
self.config.training_split, self.config.validation_split, self.config.test_split, self.config.fewshot_split
710+
]:
711+
self.dataset[split] = self.config.process_docs(self.dataset[split])
712+
713+
# copy dataset, remove image features
714+
self.dataset_no_image = self.dataset.copy()
696715
for doc_name in self.dataset_no_image:
697716
remove_cols = []
698717
features = self.dataset_no_image[doc_name].features
@@ -725,20 +744,14 @@ def has_test_docs(self) -> bool:
725744

726745
def training_docs(self) -> datasets.Dataset:
727746
if self.has_training_docs():
728-
if self.config.process_docs is not None:
729-
return self.config.process_docs(self.dataset[self.config.training_split])
730747
return self.dataset[self.config.training_split]
731748

732749
def validation_docs(self) -> datasets.Dataset:
733750
if self.has_validation_docs():
734-
if self.config.process_docs is not None:
735-
return self.config.process_docs(self.dataset[self.config.validation_split])
736751
return self.dataset[self.config.validation_split]
737752

738753
def test_docs(self) -> datasets.Dataset:
739754
if self.has_test_docs():
740-
if self.config.process_docs is not None:
741-
return self.config.process_docs(self.dataset[self.config.test_split])
742755
return self.dataset[self.config.test_split]
743756

744757
def fewshot_docs(self):
@@ -973,6 +986,8 @@ def construct_requests(self, doc_id: int, ctx: str, **kwargs) -> Union[List[Inst
973986
return Instance(request_type=self.OUTPUT_TYPE, arguments=arguments, idx=0, **kwargs)
974987

975988
def process_results(self, doc, results):
989+
if self.OUTPUT_TYPE == "generate_until":
990+
results[0] = results[0].strip()
976991
if callable(self.config.process_results):
977992
return self.config.process_results(doc, results)
978993

Diff for: lmms_eval/filters/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from lmms_eval.api.filter import FilterEnsemble
1+
from lmms_eval.api.filter import FilterEnsemble, Filter
22
from . import selection
33
from . import extraction
44
from . import transformation
@@ -13,6 +13,7 @@
1313
"lowercase": transformation.LowercaseFilter,
1414
"uppercase": transformation.UppercaseFilter,
1515
"map": transformation.MapFilter,
16+
"multi_choice_regex": extraction.MultiChoiceRegexFilter,
1617
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
1718
# that takes an input and returns a scalar and then should select the max reward,
1819
# or should implement different filters for different ways of handling a reward model's inference.

Diff for: lmms_eval/filters/extraction.py

+170-16
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,47 @@
11
import re
2-
2+
import sys
3+
import unicodedata
34
from lmms_eval.api.filter import Filter
45

56

7+
class WhitespaceFilter(Filter):
    """Filter that drops a single leading space from each model response."""

    def __init__(self) -> None:
        # Stateless filter: there is nothing to configure.
        pass

    def apply(self, resps, docs):
        """Return `resps` with one leading space removed from each response.

        `resps` is iterated as a sequence of response sequences; the nesting is
        kept and a new list of lists is returned. `docs` is accepted but unused.
        """

        def drop_leading_space(text):
            # Only the first character is removed, and only when it is a space.
            return text[1:] if text.startswith(" ") else text

        return [[drop_leading_space(text) for text in group] for group in resps]
27+
28+
629
class RegexFilter(Filter):
730
""" """
831

9-
def __init__(self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", fallback: str = "[invalid]") -> None:
32+
def __init__(
33+
self,
34+
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
35+
group_select=0,
36+
fallback: str = "[invalid]",
37+
) -> None:
1038
"""
1139
pass a string `regex` to run `re.compile(r"regex")` on.
1240
`fallback` defines the output returned if no matches for the regex are located.
1341
"""
1442
self.regex_pattern = regex_pattern
1543
self.regex = re.compile(regex_pattern)
44+
self.group_select = group_select
1645
self.fallback = fallback
1746

1847
def apply(self, resps, docs):
@@ -23,9 +52,12 @@ def apply(self, resps, docs):
2352
def filter_set(inst):
2453
filtered = []
2554
for resp in inst:
26-
match = self.regex.search(resp)
55+
match = self.regex.findall(resp)
2756
if match:
28-
match = match.group(1).strip()
57+
match = match[self.group_select]
58+
if isinstance(match, tuple):
59+
match = [m for m in match if m][0]
60+
match = match.strip()
2961
else:
3062
match = self.fallback
3163
filtered.append(match)
@@ -38,23 +70,145 @@ def filter_set(inst):
3870
return filtered_resps
3971

4072

41-
class WhitespaceFilter(Filter):
42-
""" """
73+
class MultiChoiceRegexFilter(RegexFilter):
    """
    A filter used to extract a model's answer on multiple choice questions with
    letter answers. Assumes each document has a "choices" field
    containing the list of answer choices and that the answer label symbols
    are of the form (A), (B), (C), ... or A, B, C.
    """

    # Translation table mapping every Unicode punctuation code point to None.
    # Built lazily on first use and shared by all instances, because scanning
    # the whole code-point range is expensive and the result never changes.
    _punct_tbl = None

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select=0,
        fallback: str = "[invalid]",
        ignore_case=False,
        ignore_punctuation=False,
        regexes_to_ignore=None,
    ) -> None:
        r"""
        regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
                        - step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
                        - step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
        group_select: Selects the (group_select)th match from the findall result.
        fallback: Value emitted for a response when none of the steps finds an answer.
        ignore_case: Ignores the case during step 1 matching
        ignore_punctuation: Remove the punctuation during step 1 matching
        regexes_to_ignore: Remove these regexes during step 1 matching
        """
        super().__init__(regex_pattern, group_select, fallback)
        self.ignore_case = ignore_case
        self.ignore_punctuation = ignore_punctuation
        self.regexes_to_ignore = regexes_to_ignore

    @classmethod
    def _punctuation_table(cls):
        """Return the shared punctuation-removal table, building it on first use."""
        if MultiChoiceRegexFilter._punct_tbl is None:
            MultiChoiceRegexFilter._punct_tbl = dict.fromkeys(
                i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P")
            )
        return MultiChoiceRegexFilter._punct_tbl

    def _filter_ignores(self, st):
        """Normalise `st` according to the configured ignore options."""
        if self.regexes_to_ignore is not None:
            for s in self.regexes_to_ignore:
                st = re.sub(s, "", st)

        if self.ignore_case:
            st = st.lower()

        if self.ignore_punctuation:
            # https://stackoverflow.com/a/266162
            st = st.translate(self._punctuation_table())
        return st

    def _find_match(self, regex, resp, convert_dict=None):
        """Return the selected match of `regex` in `resp`, or a falsy value.

        The match is stripped and, when it is a key of `convert_dict`, replaced
        by the mapped value. Callers only rely on the truthiness of the result
        to decide whether to try the next strategy.
        """
        matches = regex.findall(resp)
        if not matches:
            return matches
        try:
            match = matches[self.group_select]
        except IndexError:
            # Fewer matches than `group_select` asks for: treat it as "no match"
            # so the caller can fall back instead of aborting the whole run.
            return []
        if isinstance(match, tuple):
            # With several capture groups findall yields tuples; keep the first
            # non-empty group, or "" when every group is empty.
            match = next((m for m in match if m), "")
        match = match.strip()
        if match and convert_dict and match in convert_dict:
            match = convert_dict[match]
        return match

    def apply(self, resps, docs):
        # here, we assume we have a list, in which each element is
        # a list of model responses for some particular input/target pair.
        # so we process each of these (same input/target response sets)
        # independently (and keep them a list.)
        filtered_resps = []

        for r, doc in zip(resps, docs):
            # Step 1 data: normalised choice text -> "(A)", "(B)", ...
            fallback_regexes = []
            choice_to_alpha = {}
            # Step 2 data: bare letter -> "(A)", "(B)", ...
            without_paren_fallback_regexes = []
            without_paren_to_target = {}

            next_alpha = "A"
            for c in doc["choices"]:
                m = self._filter_ignores(c.strip())
                fallback_regexes.append(re.escape(m))
                choice_to_alpha[m] = f"({next_alpha})"

                without_paren_fallback_regexes.append(next_alpha)
                without_paren_to_target[next_alpha] = f"({next_alpha})"

                next_alpha = chr(ord(next_alpha) + 1)

            fallback_regex = re.compile("|".join(fallback_regexes))
            letter_alternatives = "|".join(without_paren_fallback_regexes)
            # Raw f-string so that `\s` reaches the regex engine without an
            # invalid-escape warning; the compiled pattern is unchanged.
            without_paren_fallback_regex = re.compile(rf":[\s]*({letter_alternatives})")

            filtered = []
            for resp in r:
                # Primary pattern first, then the choice text, then a bare
                # letter following a colon.
                match = self._find_match(self.regex, resp)
                if not match:
                    match = self._find_match(fallback_regex, self._filter_ignores(resp), choice_to_alpha)
                if not match:
                    match = self._find_match(without_paren_fallback_regex, resp, without_paren_to_target)
                if not match:
                    match = self.fallback
                filtered.append(match)
            filtered_resps.append(filtered)

        return filtered_resps
173+
174+
175+
class ExtendedRegexFilter(RegexFilter):
    """RegexFilter variant that adds text-normalisation and match-extraction helpers.

    This class does not override `apply`; it provides `filter_ignores` and
    `find_match` on top of the attributes set up by `RegexFilter.__init__`.
    """

    # Translation table mapping every Unicode punctuation code point to None,
    # built once at class creation and used by `filter_ignores`.
    punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P"))

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select=0,
        fallback: str = "[invalid]",
        ignore_case=False,
        ignore_punctuation=False,
        regexes_to_ignore=None,
    ) -> None:
        """
        regex_pattern, group_select, fallback: forwarded to `RegexFilter.__init__`.
        ignore_case: lower-case the text in `filter_ignores`.
        ignore_punctuation: remove Unicode punctuation in `filter_ignores`.
        regexes_to_ignore: patterns whose matches are removed in `filter_ignores`.
        """
        super().__init__(regex_pattern, group_select, fallback)
        self.ignore_case = ignore_case
        self.ignore_punctuation = ignore_punctuation
        self.regexes_to_ignore = regexes_to_ignore

    def filter_ignores(self, st):
        """Return `st` normalised according to the configured ignore options."""
        if self.regexes_to_ignore is not None:
            for s in self.regexes_to_ignore:
                st = re.sub(s, "", st)

        if self.ignore_case:
            st = st.lower()

        if self.ignore_punctuation:
            # https://stackoverflow.com/a/266162
            st = st.translate(self.punct_tbl)
        return st

    def find_match(self, regex, resp, convert_dict=None):
        """Return the `group_select`-th match of `regex` in `resp`, or a falsy value.

        The match is stripped and, when it is a key of `convert_dict`, replaced
        by the mapped value. An empty list is returned when nothing matches.
        """
        matches = regex.findall(resp)
        if not matches:
            return matches
        try:
            match = matches[self.group_select]
        except IndexError:
            # Fewer matches than `group_select` asks for: report "no match"
            # instead of raising, so callers can apply their fallback.
            return []
        if isinstance(match, tuple):
            # With several capture groups findall yields tuples; keep the first
            # non-empty group, or "" when every group is empty.
            match = next((m for m in match if m), "")
        match = match.strip()
        if match and convert_dict and match in convert_dict:
            match = convert_dict[match]
        return match

Diff for: lmms_eval/models/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
AVAILABLE_MODELS = {
44
"llava": "Llava",
5+
"llava_hf": "LlavaHf",
6+
"llava_sglang": "LlavaSglang",
57
"qwen_vl": "Qwen_VL",
68
"fuyu": "Fuyu",
79
"gpt4v": "GPT4V",

0 commit comments

Comments
 (0)