Skip to content

Commit 0e46269

Browse files
Math extraction - allow only trying the first match, more customizable latex extraction + bump deps (#522)
* extract matching * better docstring * lazy imports * bump up math * Update src/lighteval/metrics/dynamic_metrics.py Co-authored-by: Clémentine Fourrier <[email protected]> * fix pr commnets * Apply suggestions from code review Co-authored-by: Clémentine Fourrier <[email protected]> * rename comparisson -> comparison * fix expr numbers extraction with currency or units * add test for correct extraction of failed answer * bump of latex2sympy2 version, add new tests for extract metric * bump up latex2sympy + adjust latex target * revert gold target timoeut * remove dead comment 💀 * add doc --------- Co-authored-by: Clémentine Fourrier <[email protected]>
1 parent cb075a5 commit 0e46269

File tree

3 files changed

+43
-11
lines changed

3 files changed

+43
-11
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ multilingual = [
109109
"jieba", # for chinese tokenizer
110110
"pyvi", # for vietnamese tokenizer
111111
]
112-
math = ["latex2sympy2_extended>=0.9.1"]
112+
math = ["latex2sympy2_extended>=0.9.3"]
113113

114114
[project.urls]
115115
Homepage = "https://github.com/huggingface/lighteval"

src/lighteval/metrics/dynamic_metrics.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def multilingual_extractive_match_metric(
191191
pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
192192
aggregation_function: Callable[[list[float]], float] = max,
193193
fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
194+
extraction_mode: Literal["first_match", "any_match"] = "any_match",
194195
precision: int = 6,
195196
) -> SampleLevelMetric:
196197
"""Creates a language-aware extractive match metric that extracts answers from the model's output.
@@ -215,6 +216,10 @@ def multilingual_extractive_match_metric(
215216
How to perform extraction. Defaults to "first_match".
216217
- "no_fallback": Only use first successfully parsed matches
217218
- "first_match": Use the first successfully parsed match + first match irregardless the parsing success
219+
extraction_mode: Literal["first_match", "any_match"]
220+
- "first_match": Only tries to extract the first regex match if it fails no other matches are tried
221+
- "any_match": Tries to extract any regex match
222+
218223
precision: int
219224
Number of decimal places to use when comparing numerical values. Defaults to 6.
220225
@@ -240,9 +245,12 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc
240245
pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language)
241246

242247
extracted_predictions = [
243-
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode) for pred in predictions
248+
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode)
249+
for pred in predictions
250+
]
251+
extracted_golds = [
252+
extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode) for gold in golds
244253
]
245-
extracted_golds = [extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode) for gold in golds]
246254

247255
# Assert on empty gold and warn on empty pred
248256
if any(len(g) == 0 for g in extracted_golds):

src/lighteval/metrics/utils/extractive_match_utils.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
# SOFTWARE.
2222

2323
import re
24-
from dataclasses import dataclass
24+
from dataclasses import dataclass, field
2525
from functools import lru_cache
2626
from itertools import groupby
27-
from typing import Literal, Sequence
27+
from typing import Any, Literal, Sequence
2828

2929
import sympy
3030
from sympy import Basic, MatrixBase, Number
@@ -39,17 +39,33 @@
3939
from lighteval.utils.timeout import timeout
4040

4141

42+
@requires_latex2sympy2_extended
43+
def latex_normalization_config_default_factory():
44+
from latex2sympy2_extended.latex2sympy2 import NormalizationConfig
45+
46+
return NormalizationConfig(
47+
basic_latex=True,
48+
units=True,
49+
malformed_operators=True,
50+
nits=True,
51+
boxed=True,
52+
equations=True,
53+
)
54+
55+
4256
@dataclass(frozen=True)
4357
class LatexExtractionConfig:
4458
"""Config for extracting latex from the prediction.
4559
4660
Attributes:
4761
try_extract_without_anchor (bool): Whether to try extracting latex without requiring specific anchors like "answer:" or "final answer is"
48-
enforce_boxed_match (bool): Whether to also consider extracting from plain \boxed{...} expressions
62+
boxed_match_priority (int): Priority of the boxed match regex (-1 never, 0 first, 55 after final answer: anchor, etc...)
63+
normalization_config (latex2sympy2_extended.latex2sympy2.NormalizationConfig): Normalization config to use for latex extraction
4964
"""
5065

5166
try_extract_without_anchor: bool = True
52-
enforce_boxed_match: bool = True
67+
boxed_match_priority: int = 55
68+
normalization_config: Any = field(default_factory=latex_normalization_config_default_factory)
5369

5470

5571
@dataclass(frozen=True)
@@ -187,9 +203,8 @@ def lazy_latex_regex(latex_config: LatexExtractionConfig, language: Language) ->
187203
if latex_config.try_extract_without_anchor:
188204
regexes.append((latex_re, 300))
189205

190-
# This ensures that boxed is matched right after the final answer xxxx
191-
if latex_config.enforce_boxed_match:
192-
regexes.append((latex_boxed, 55))
206+
if latex_config.boxed_match_priority >= 0:
207+
regexes.append((latex_boxed, latex_config.boxed_match_priority))
193208

194209
return [(re.compile(pattern, re.DOTALL), priority) for pattern, priority in regexes]
195210

@@ -387,6 +402,7 @@ def extract_target_from_pred(
387402
pred: str,
388403
target_res: list[tuple[list[tuple[re.Pattern[str], int]], ExtractionTarget]],
389404
fallback_mode: Literal["no_fallback", "first_match"] = "no_fallback",
405+
extraction_mode: Literal["first_match", "any_match"] = "any_match",
390406
):
391407
"""Extracts targets from a prediction string using regex patterns.
392408
Returns first sucesffuly extracted match.
@@ -397,6 +413,9 @@ def extract_target_from_pred(
397413
fallback_mode (Literal["no_fallback", "first_match"], optional): How to handle extraction failures. Defaults to "no_fallback".
398414
- "no_fallback": Return only successfully parsed match
399415
- "first_match": Additionaly Include the first string match no matter how parsing finished
416+
extraction_mode (Literal["first_match", "any_match"], optional): How to handle extraction failures. Defaults to "any_match".
417+
- "first_match": Only tries to extract the first match
418+
- "any_match": Tries to extract any match
400419
401420
Returns:
402421
list: List of extracted predictions, with first fallbac string appended if fallback_mode is "first_match"
@@ -410,6 +429,7 @@ def extract_target_from_pred(
410429
for target_patterns, target_type in target_res
411430
for pattern, priority in target_patterns
412431
]
432+
match_found = False
413433

414434
# Group patterns by priority using itertools.groupby
415435
for _, patterns_group in groupby(sorted(all_patterns, key=lambda x: x[2]), key=lambda x: x[2]):
@@ -426,6 +446,7 @@ def extract_target_from_pred(
426446
# Try to extract from each match, starting from rightmost
427447
for match, _, _, target_type in matches_with_pos:
428448
extracted_match, str_fallback = extract_match(match, target_type)
449+
match_found = True
429450

430451
if str_fallback:
431452
fallbacks.append(str_fallback)
@@ -434,8 +455,11 @@ def extract_target_from_pred(
434455
extracted_predictions.append(extracted_match)
435456
break
436457

458+
if extraction_mode == "first_match":
459+
break
460+
437461
# If we found something and we're in first_match mode, stop processing other priorities
438-
if extracted_predictions:
462+
if extracted_predictions or (match_found and extraction_mode == "first_match"):
439463
break
440464

441465
if fallback_mode == "first_match" and fallbacks:

0 commit comments

Comments
 (0)