2121# SOFTWARE. 
2222
2323import  re 
24- from  dataclasses  import  dataclass 
24+ from  dataclasses  import  dataclass ,  field 
2525from  functools  import  lru_cache 
2626from  itertools  import  groupby 
27- from  typing  import  Literal , Sequence 
27+ from  typing  import  Any ,  Literal , Sequence 
2828
2929import  sympy 
3030from  sympy  import  Basic , MatrixBase , Number 
3939from  lighteval .utils .timeout  import  timeout 
4040
4141
42+ @requires_latex2sympy2_extended  
43+ def  latex_normalization_config_default_factory ():
44+     from  latex2sympy2_extended .latex2sympy2  import  NormalizationConfig 
45+ 
46+     return  NormalizationConfig (
47+         basic_latex = True ,
48+         units = True ,
49+         malformed_operators = True ,
50+         nits = True ,
51+         boxed = True ,
52+         equations = True ,
53+     )
54+ 
55+ 
4256@dataclass (frozen = True ) 
4357class  LatexExtractionConfig :
4458    """Config for extracting latex from the prediction. 
4559
4660    Attributes: 
4761        try_extract_without_anchor (bool): Whether to try extracting latex without requiring specific anchors like "answer:" or "final answer is" 
48-         enforce_boxed_match (bool): Whether to also consider extracting from plain \b oxed{...} expressions 
62+         boxed_match_priority (int): Priority of the boxed match regex (-1 never, 0 first, 55 after final answer: anchor, etc...) 
63+         normalization_config (latex2sympy2_extended.latex2sympy2.NormalizationConfig): Normalization config to use for latex extraction 
4964    """ 
5065
5166    try_extract_without_anchor : bool  =  True 
52-     enforce_boxed_match : bool  =  True 
67+     boxed_match_priority : int  =  55 
68+     normalization_config : Any  =  field (default_factory = latex_normalization_config_default_factory )
5369
5470
5571@dataclass (frozen = True ) 
@@ -187,9 +203,8 @@ def lazy_latex_regex(latex_config: LatexExtractionConfig, language: Language) ->
187203        if  latex_config .try_extract_without_anchor :
188204            regexes .append ((latex_re , 300 ))
189205
190-     # This ensures that boxed is matched right after the final answer xxxx 
191-     if  latex_config .enforce_boxed_match :
192-         regexes .append ((latex_boxed , 55 ))
206+     if  latex_config .boxed_match_priority  >=  0 :
207+         regexes .append ((latex_boxed , latex_config .boxed_match_priority ))
193208
194209    return  [(re .compile (pattern , re .DOTALL ), priority ) for  pattern , priority  in  regexes ]
195210
@@ -387,6 +402,7 @@ def extract_target_from_pred(
387402    pred : str ,
388403    target_res : list [tuple [list [tuple [re .Pattern [str ], int ]], ExtractionTarget ]],
389404    fallback_mode : Literal ["no_fallback" , "first_match" ] =  "no_fallback" ,
405+     extraction_mode : Literal ["first_match" , "any_match" ] =  "any_match" ,
390406):
391407    """Extracts targets from a prediction string using regex patterns. 
392408    Returns first sucesffuly extracted match. 
@@ -397,6 +413,9 @@ def extract_target_from_pred(
397413        fallback_mode (Literal["no_fallback", "first_match"], optional): How to handle extraction failures. Defaults to "no_fallback". 
398414            - "no_fallback": Return only successfully parsed match 
399415            - "first_match": Additionaly Include the first string match no matter how parsing finished 
416+         extraction_mode (Literal["first_match", "any_match"], optional): How to handle extraction failures. Defaults to "any_match". 
417+             - "first_match": Only tries to extract the first match 
418+             - "any_match": Tries to extract any match 
400419
401420    Returns: 
402421        list: List of extracted predictions, with first fallbac string appended if fallback_mode is "first_match" 
@@ -410,6 +429,7 @@ def extract_target_from_pred(
410429        for  target_patterns , target_type  in  target_res 
411430        for  pattern , priority  in  target_patterns 
412431    ]
432+     match_found  =  False 
413433
414434    # Group patterns by priority using itertools.groupby 
415435    for  _ , patterns_group  in  groupby (sorted (all_patterns , key = lambda  x : x [2 ]), key = lambda  x : x [2 ]):
@@ -426,6 +446,7 @@ def extract_target_from_pred(
426446        # Try to extract from each match, starting from rightmost 
427447        for  match , _ , _ , target_type  in  matches_with_pos :
428448            extracted_match , str_fallback  =  extract_match (match , target_type )
449+             match_found  =  True 
429450
430451            if  str_fallback :
431452                fallbacks .append (str_fallback )
@@ -434,8 +455,11 @@ def extract_target_from_pred(
434455                extracted_predictions .append (extracted_match )
435456                break 
436457
458+             if  extraction_mode  ==  "first_match" :
459+                 break 
460+ 
437461        # If we found something and we're in first_match mode, stop processing other priorities 
438-         if  extracted_predictions :
462+         if  extracted_predictions   or  ( match_found   and   extraction_mode   ==   "first_match" ) :
439463            break 
440464
441465    if  fallback_mode  ==  "first_match"  and  fallbacks :
0 commit comments