Skip to content

Commit 5079a2a

Browse files
authored
Sync Math-verify (#535)
* update extraction match to reflect newest math-verify * revert symbols, improve sets handling * rm todo * fmt + remove empty excepts + bump l2s * fmt * docstring
1 parent 322a843 commit 5079a2a

File tree

6 files changed

+577
-126
lines changed

6 files changed

+577
-126
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ multilingual = [
109109
"jieba", # for chinese tokenizer
110110
"pyvi", # for vietnamese tokenizer
111111
]
112-
math = ["latex2sympy2_extended>=0.9.3"]
112+
math = ["latex2sympy2_extended==1.0.4"]
113113

114114
[project.urls]
115115
Homepage = "https://github.com/huggingface/lighteval"

src/lighteval/metrics/dynamic_metrics.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ def multilingual_extractive_match_metric(
193193
fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
194194
extraction_mode: Literal["first_match", "any_match"] = "any_match",
195195
precision: int = 6,
196+
timeout_seconds: int = 5,
196197
) -> SampleLevelMetric:
197198
"""Creates a language-aware extractive match metric that extracts answers from the model's output.
198199
@@ -222,6 +223,8 @@ def multilingual_extractive_match_metric(
222223
223224
precision: int
224225
Number of decimal places to use when comparing numerical values. Defaults to 6.
226+
timeout_seconds: int
227+
Timeout for the extraction (each attempt) and comparison. Defaults to 5.
225228
226229
Returns:
227230
A sample level metric that extracts and compares mathematical expressions.
@@ -245,11 +248,12 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc
245248
pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language)
246249

247250
extracted_predictions = [
248-
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode)
251+
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds)
249252
for pred in predictions
250253
]
251254
extracted_golds = [
252-
extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode) for gold in golds
255+
extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds)
256+
for gold in golds
253257
]
254258

255259
# Assert on empty gold and warn on empty pred
@@ -265,12 +269,19 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc
265269
# We have to use a timeout because the sympy-to-str conversion can be very slow
266270
try:
267271
add_to_specifics_with_timeout(formatted_doc, extracted_predictions, extracted_golds)
268-
except: # noqa: E722
272+
except Exception: # noqa: E722
269273
logger.warning("Timeout when adding extracted predictions and golds to specific")
270274

271275
return aggregation_function(
272276
[
273-
(1.0 if any(compare_gold_target(gold, pred, precision) for gold in extracted_golds) else 0.0)
277+
(
278+
1.0
279+
if any(
280+
compare_gold_target(gold, pred, precision, timeout_seconds=timeout_seconds)
281+
for gold in extracted_golds
282+
)
283+
else 0.0
284+
)
274285
for pred in extracted_predictions
275286
]
276287
)

0 commit comments

Comments
 (0)