1
1
import re
2
-
2
+ import sys
3
+ import unicodedata
3
4
from lmms_eval .api .filter import Filter
4
5
5
6
7
class WhitespaceFilter(Filter):
    """Filter that removes one leading space from each model response."""

    def __init__(self) -> None:
        """No configuration is needed for this filter."""

    def apply(self, resps, docs):
        """
        Strip a single leading space from every response.

        `resps` is iterated as a sequence of response groups, each group being
        an iterable of strings. `docs` is accepted but not used here.
        Returns a list of lists with the same layout as `resps`.
        """

        def drop_leading_space(text):
            # Only one space is removed; any further leading whitespace stays.
            return text[1:] if text.startswith(" ") else text

        return [[drop_leading_space(text) for text in group] for group in resps]
27
+
28
+
6
29
class RegexFilter (Filter ):
7
30
""" """
8
31
9
- def __init__ (self , regex_pattern : str = r"#### (\-?[0-9\.\,]+)" , fallback : str = "[invalid]" ) -> None :
32
def __init__(
    self,
    regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
    group_select=0,
    fallback: str = "[invalid]",
) -> None:
    """
    Configure the regex-based extraction.

    regex_pattern: pattern string; it is compiled once with `re.compile`.
    group_select: index into the list of matches found in a response,
        used to pick which match is kept.
    fallback: the output returned if no matches for the regex are located.
    """
    self.regex_pattern = regex_pattern
    self.group_select = group_select
    self.fallback = fallback
    # Compile up front so the pattern is not re-parsed for every response.
    self.regex = re.compile(self.regex_pattern)
17
46
18
47
def apply (self , resps , docs ):
@@ -23,9 +52,12 @@ def apply(self, resps, docs):
23
52
def filter_set (inst ):
24
53
filtered = []
25
54
for resp in inst :
26
- match = self .regex .search (resp )
55
+ match = self .regex .findall (resp )
27
56
if match :
28
- match = match .group (1 ).strip ()
57
+ match = match [self .group_select ]
58
+ if isinstance (match , tuple ):
59
+ match = [m for m in match if m ][0 ]
60
+ match = match .strip ()
29
61
else :
30
62
match = self .fallback
31
63
filtered .append (match )
@@ -38,23 +70,145 @@ def filter_set(inst):
38
70
return filtered_resps
39
71
40
72
41
- class WhitespaceFilter (Filter ):
42
- """ """
73
class MultiChoiceRegexFilter(RegexFilter):
    """
    A filter used to extract a model's answer on multiple choice questions with
    letter answers. Assumes each document has a "choices" field
    containing the list of answer choices and that the answer label symbols
    are of the form (A), (B), (C), ... or A, B, C.
    """

    # Translation table that deletes every Unicode punctuation character.
    # Building it scans all code points, so it is created lazily (only when
    # `ignore_punctuation` is actually used) and then shared by all instances.
    _punct_tbl = None

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select=0,
        fallback: str = "[invalid]",
        ignore_case=False,
        ignore_punctuation=False,
        regexes_to_ignore=None,
    ) -> None:
        r"""
        regex_pattern: The basic regex pattern to use. If it fails to match, the
            customized match procedure is used:
            - step 1 : the text of each choice is searched for in the response.
            - step 2 : the choice letter is parsed with the regex :[\s]*([A-?]),
              where ? varies by number of choices.
        group_select: Selects the (group_select)th match from the findall result.
        fallback: Value emitted when none of the matching steps finds an answer.
        ignore_case: Ignores the case during step 1 matching.
        ignore_punctuation: Removes the punctuation during step 1 matching.
        regexes_to_ignore: Removes these regexes during step 1 matching.
        """
        super().__init__(regex_pattern, group_select, fallback)
        self.ignore_case = ignore_case
        self.ignore_punctuation = ignore_punctuation
        self.regexes_to_ignore = regexes_to_ignore

    @classmethod
    def _punctuation_table(cls):
        """Return the cached punctuation-deleting table, building it on first use."""
        if MultiChoiceRegexFilter._punct_tbl is None:
            MultiChoiceRegexFilter._punct_tbl = dict.fromkeys(
                i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P")
            )
        return MultiChoiceRegexFilter._punct_tbl

    def _filter_ignores(self, st):
        """Apply the configured normalisations (ignored regexes, case, punctuation) to `st`."""
        if self.regexes_to_ignore is not None:
            for s in self.regexes_to_ignore:
                st = re.sub(s, "", st)

        if self.ignore_case:
            st = st.lower()

        if self.ignore_punctuation:
            # https://stackoverflow.com/a/266162
            st = st.translate(self._punctuation_table())
        return st

    def _find_match(self, regex, resp, convert_dict=None):
        """
        Return the selected match of `regex` in `resp`, or None when there is none.

        The match at index `self.group_select` of `regex.findall(resp)` is used.
        For patterns with several groups the first non-empty group is taken.
        The stripped match is mapped through `convert_dict` when it is a key there.
        """
        matches = regex.findall(resp)
        try:
            match = matches[self.group_select]
        except IndexError:
            # No match at all, or fewer matches than `group_select` asks for:
            # report "no match" so the caller can try its next strategy.
            return None

        if isinstance(match, tuple):
            # Multi-group pattern: keep the first group that captured something.
            match = next((m for m in match if m), None)
            if match is None:
                return None

        match = match.strip()
        if match and convert_dict:
            match = convert_dict.get(match, match)
        return match

    def apply(self, resps, docs):
        """
        Extract one answer per response.

        `resps` and `docs` are walked in lockstep; each element of `resps` is a
        list of model responses for one document, and each document must provide
        `doc["choices"]`. For every response the configured regex is tried first,
        then the choice texts, then a ": <letter>" pattern; `self.fallback` is
        used when all of them fail. Returns a list of lists mirroring `resps`.
        """
        # here, we assume we have a list, in which each element is
        # a list of model responses for some particular input/target pair.
        # so we process each of these (same input/target response sets)
        # independently (and keep them a list.)
        filtered_resps = []

        for r, doc in zip(resps, docs):
            fallback_regexes = []
            choice_to_alpha = {}
            next_alpha = "A"

            without_paren_fallback_regexes = []
            without_paren_to_target = {}

            for c in doc["choices"]:
                m = self._filter_ignores(c.strip())
                fallback_regexes.append(re.escape(m))
                choice_to_alpha[m] = f"({next_alpha})"

                # Escaped so that labels past "Z" cannot inject regex syntax.
                without_paren_fallback_regexes.append(re.escape(next_alpha))
                without_paren_to_target[next_alpha] = f"({next_alpha})"

                next_alpha = chr(ord(next_alpha) + 1)

            fallback_regex = re.compile("|".join(fallback_regexes))
            letters = "|".join(without_paren_fallback_regexes)
            without_paren_fallback_regex = re.compile(rf":[\s]*({letters})")

            filtered = []
            for resp in r:
                match = self._find_match(self.regex, resp)
                if not match:
                    match = self._find_match(fallback_regex, self._filter_ignores(resp), choice_to_alpha)
                if not match:
                    match = self._find_match(without_paren_fallback_regex, resp, without_paren_to_target)
                if not match:
                    match = self.fallback
                filtered.append(match)
            filtered_resps.append(filtered)

        return filtered_resps
173
+
174
+
175
class ExtendedRegexFilter(RegexFilter):
    """
    RegexFilter variant that adds text normalisation and match-lookup helpers.

    `filter_ignores` normalises a string according to the configured options and
    `find_match` extracts the selected regex match, optionally translating it
    through a lookup dictionary.
    """

    # Translation table that deletes every Unicode punctuation character;
    # computed once when the class is created and shared by all instances.
    punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P"))

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select=0,
        fallback: str = "[invalid]",
        ignore_case=False,
        ignore_punctuation=False,
        regexes_to_ignore=None,
    ) -> None:
        """
        regex_pattern, group_select, fallback: forwarded to RegexFilter.
        ignore_case: lower-case text in `filter_ignores`.
        ignore_punctuation: delete punctuation characters in `filter_ignores`.
        regexes_to_ignore: patterns whose matches `filter_ignores` removes.
        """
        super().__init__(regex_pattern, group_select, fallback)
        self.ignore_case = ignore_case
        self.ignore_punctuation = ignore_punctuation
        self.regexes_to_ignore = regexes_to_ignore

    def filter_ignores(self, st):
        """Return `st` with ignored regexes removed, then case and punctuation normalised as configured."""
        if self.regexes_to_ignore is not None:
            for s in self.regexes_to_ignore:
                st = re.sub(s, "", st)

        if self.ignore_case:
            st = st.lower()

        if self.ignore_punctuation:
            # https://stackoverflow.com/a/266162
            st = st.translate(self.punct_tbl)
        return st

    def find_match(self, regex, resp, convert_dict=None):
        """
        Return the selected match of `regex` in `resp`.

        The match at index `self.group_select` of `regex.findall(resp)` is used.
        For patterns with several groups the first non-empty group is taken.
        The stripped match is mapped through `convert_dict` when it is a key there.
        An empty list is returned when nothing usable matched, so callers can
        test the result for truthiness.
        """
        matches = regex.findall(resp)
        try:
            match = matches[self.group_select]
        except IndexError:
            # No match at all, or fewer matches than `group_select` asks for.
            return []

        if isinstance(match, tuple):
            # Multi-group pattern: keep the first group that captured something.
            match = next((m for m in match if m), None)
            if match is None:
                return []

        match = match.strip()
        if match and convert_dict:
            match = convert_dict.get(match, match)
        return match
0 commit comments