@@ -193,6 +193,7 @@ def multilingual_extractive_match_metric(
193
193
fallback_mode : Literal ["no_fallback" , "first_match" ] = "first_match" ,
194
194
extraction_mode : Literal ["first_match" , "any_match" ] = "any_match" ,
195
195
precision : int = 6 ,
196
+ timeout_seconds : int = 5 ,
196
197
) -> SampleLevelMetric :
197
198
"""Creates a language-aware extractive match metric that extracts answers from the model's output.
198
199
@@ -222,6 +223,8 @@ def multilingual_extractive_match_metric(
222
223
223
224
precision: int
224
225
Number of decimal places to use when comparing numerical values. Defaults to 6.
226
+ timeout_seconds: int
227
+ Timeout for the extraction (each attempt) and comparison. Defaults to 5.
225
228
226
229
Returns:
227
230
A sample level metric that extracts and compares mathematical expressions.
@@ -245,11 +248,12 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc
245
248
pred_extraction_regexes = get_extraction_regexes (formatted_doc , pred_extraction_target , language )
246
249
247
250
extracted_predictions = [
248
- extract_target_from_pred (pred , pred_extraction_regexes , fallback_mode , extraction_mode )
251
+ extract_target_from_pred (pred , pred_extraction_regexes , fallback_mode , extraction_mode , timeout_seconds )
249
252
for pred in predictions
250
253
]
251
254
extracted_golds = [
252
- extract_target_from_pred (gold , gold_extraction_regexes , fallback_mode , extraction_mode ) for gold in golds
255
+ extract_target_from_pred (gold , gold_extraction_regexes , fallback_mode , extraction_mode , timeout_seconds )
256
+ for gold in golds
253
257
]
254
258
255
259
# Assert on empty gold and warn on empty pred
@@ -265,12 +269,19 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc
265
269
# We have to use a timeout because the sympy-to-str conversion can be very slow
266
270
try :
267
271
add_to_specifics_with_timeout (formatted_doc , extracted_predictions , extracted_golds )
268
- except : # noqa: E722
272
+ except Exception : # noqa: E722
269
273
logger .warning ("Timeout when adding extracted predictions and golds to specific" )
270
274
271
275
return aggregation_function (
272
276
[
273
- (1.0 if any (compare_gold_target (gold , pred , precision ) for gold in extracted_golds ) else 0.0 )
277
+ (
278
+ 1.0
279
+ if any (
280
+ compare_gold_target (gold , pred , precision , timeout_seconds = timeout_seconds )
281
+ for gold in extracted_golds
282
+ )
283
+ else 0.0
284
+ )
274
285
for pred in extracted_predictions
275
286
]
276
287
)
0 commit comments