Commit 4561df4

Extractive Match metric (#495)

* extract matching
* better docstring
* lazy imports
* bump up math
* Update src/lighteval/metrics/dynamic_metrics.py
* fix pr comments
* Apply suggestions from code review
* rename comparisson -> comparison

Co-authored-by: Clémentine Fourrier <[email protected]>
1 parent 3239838 commit 4561df4

File tree

8 files changed: +2052 −3 lines

pyproject.toml (2 additions & 1 deletion)

```diff
@@ -95,7 +95,7 @@ tensorboardX = ["tensorboardX"]
 vllm = ["vllm", "ray", "more_itertools"]
 quality = ["ruff==v0.2.2","pre-commit"]
 tests = ["pytest==7.4.0"]
-dev = ["lighteval[accelerate,quality,tests,multilingual]"]
+dev = ["lighteval[accelerate,quality,tests,multilingual,math]"]
 docs = ["hf-doc-builder", "watchdog"]
 extended_tasks = [
     "langdetect", # ifeval
@@ -109,6 +109,7 @@ multilingual = [
     "jieba", # for chinese tokenizer
     "pyvi", # for vietnamese tokenizer
 ]
+math = ["latex2sympy2_extended>=0.9.0"]
 
 [project.urls]
 Homepage = "https://github.com/huggingface/lighteval"
```
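With the new `math` extra declared (and folded into `dev`), the `latex2sympy2_extended` dependency can be pulled in with `pip install "lighteval[math]"`, or transitively via `pip install "lighteval[dev]"`.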

src/lighteval/metrics/dynamic_metrics.py (107 additions & 1 deletion)

```diff
@@ -20,7 +20,8 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-from typing import Callable, Literal
+import logging
+from typing import Callable, Literal, Sequence
 
 import numpy as np
 
@@ -37,8 +38,22 @@
     LogProbTokenNorm,
     get_multilingual_normalizer,
 )
+from lighteval.metrics.utils.extractive_match_utils import (  # noqa: F401
+    ExprExtractionConfig,
+    ExtractionTarget,
+    IndicesExtractionConfig,
+    LatexExtractionConfig,
+    extract_target_from_pred,
+    get_extraction_regexes,
+)
+from lighteval.metrics.utils.math_comparison import compare_gold_target
 from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase, SampleLevelMetric
+from lighteval.tasks.requests import Doc
 from lighteval.utils.language import Language
+from lighteval.utils.timeout import timeout
+
+
+logger = logging.getLogger(__name__)
 
 
 def loglikelihood_acc_metric(normalization: LogProbNormalization | None = None) -> SampleLevelMetric:
@@ -168,3 +183,94 @@ def multilingual_quasi_exact_match_metric(
         corpus_level_fn=np.mean,
         higher_is_better=True,
     )
+
+
+def multilingual_extractive_match_metric(
+    language: Language = Language.ENGLISH,
+    gold_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
+    pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
+    aggregation_function: Callable[[list[float]], float] = max,
+    fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
+    precision: int = 6,
+) -> SampleLevelMetric:
+    """Creates a language-aware extractive match metric that extracts answers from the model's output.
+
+    Known issues:
+    - If the task is to simplify an expression, the metric might overestimate accuracy: when the model outputs no anchor for extraction (e.g. "final answer is ..."),
+      the extracted prediction may be the expression to simplify itself. Because we run simplification ourselves, sympy can then simplify it correctly,
+      so it will match the gold despite the model not doing anything. PRs to fix this are welcome.
+
+    - There is currently no StringExtractionConfig, so if the gold is \boxed{\text{Friday}} and the model outputs Friday, it will not match, because nothing will be extracted.
+
+    Args:
+        language: Language
+            The language of the samples.
+        gold_extraction_target: Sequence[ExtractionTarget]
+            Extraction targets to use for gold answers. Defaults to extracting simple math expressions.
+        pred_extraction_target: Sequence[ExtractionTarget]
+            Extraction targets to use for predictions. Defaults to extracting simple math expressions.
+        aggregation_function: Callable[[list[float]], float]
+            Function to aggregate scores when multiple golds/predictions are present. Defaults to max.
+        fallback_mode: Literal["no_fallback", "first_match"]
+            How to perform extraction. Defaults to "first_match".
+            - "no_fallback": Only use matches that were successfully parsed
+            - "first_match": Use the first successfully parsed match, plus the first match regardless of parsing success
+        precision: int
+            Number of decimal places to use when comparing numerical values. Defaults to 6.
+
+    Returns:
+        A sample-level metric that extracts and compares mathematical expressions.
+
+    """
+
+    @timeout(2)
+    def add_to_specifics_with_timeout(
+        formatted_doc: Doc, extracted_predictions: list[list[str]], extracted_golds: list[list[str]]
+    ) -> None:
+        if formatted_doc.specific is None:
+            formatted_doc.specific = {}
+
+        formatted_doc.specific["extracted_predictions"] = [
+            str(pred) for preds in extracted_predictions for pred in preds
+        ]
+        formatted_doc.specific["extracted_golds"] = [str(gold) for golds in extracted_golds for gold in golds]
+
+    def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc) -> float:
+        gold_extraction_regexes = get_extraction_regexes(formatted_doc, gold_extraction_target, language)
+        pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language)
+
+        extracted_predictions = [
+            extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode) for pred in predictions
+        ]
+        extracted_golds = [extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode) for gold in golds]
+
+        # Raise on empty gold and warn on empty pred
+        if any(len(g) == 0 for g in extracted_golds):
+            raise ValueError(f"No gold targets found for at least one gold. Gold: {golds}, Pred: {predictions}")
+
+        if all(len(p) == 0 for p in extracted_predictions):
+            logger.warning(
+                f"We did not manage to extract a prediction in the correct format. Gold: {golds}, Pred: {predictions}"
+            )
+
+        # We have to use a timeout because the sympy-to-str conversion can be very slow
+        try:
+            add_to_specifics_with_timeout(formatted_doc, extracted_predictions, extracted_golds)
+        except:  # noqa: E722
+            logger.warning("Timeout when adding extracted predictions and golds to specific")
+
+        return aggregation_function(
+            [
+                (1.0 if any(compare_gold_target(gold, pred, precision) for gold in extracted_golds) else 0.0)
+                for pred in extracted_predictions
+            ]
+        )
+
+    return SampleLevelMetric(
+        metric_name="extractive_match",
+        sample_level_fn=sample_level_fn,
+        category=MetricCategory.GENERATIVE,
+        use_case=MetricUseCase.ACCURACY,
+        corpus_level_fn=np.mean,
+        higher_is_better=True,
+    )
```
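The `# noqa: F401` on the import block re-exports the extraction configs from `dynamic_metrics`, so task definitions can pull everything from one module. To make the new metric concrete, here is a minimal usage sketch, assuming only the APIs shown in the diff above; the `Doc` construction and the direct `sample_level_fn` call are illustrative and may need adjusting to the exact lighteval version.

```python
# Minimal sketch: build the extractive-match metric for LaTeX-formatted
# answers and score a single prediction against a single gold.
from lighteval.metrics.dynamic_metrics import (
    LatexExtractionConfig,
    multilingual_extractive_match_metric,
)
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language

metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    gold_extraction_target=(LatexExtractionConfig(),),
    pred_extraction_target=(LatexExtractionConfig(),),
    precision=6,  # compare numbers up to 6 decimal places
)

# Illustrative Doc; real tasks build Doc objects in their prompt function.
doc = Doc(query="What is 1/2 + 1/3?", choices=["$\\frac{5}{6}$"], gold_index=0)

score = metric.sample_level_fn(
    golds=["$\\frac{5}{6}$"],
    predictions=["Adding the fractions gives $\\boxed{\\frac{5}{6}}$."],
    formatted_doc=doc,
)
print(score)  # 1.0 when an extracted prediction matches an extracted gold
```

Because `aggregation_function` defaults to `max`, a sample scores 1.0 as soon as any prediction matches any gold, which is the natural choice for single-answer generative tasks.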
