diff --git a/nemo_skills/evaluation/evaluator/mcq.py b/nemo_skills/evaluation/evaluator/mcq.py index 522aeca8c3..e057fb3a12 100644 --- a/nemo_skills/evaluation/evaluator/mcq.py +++ b/nemo_skills/evaluation/evaluator/mcq.py @@ -19,34 +19,56 @@ from tqdm import tqdm from nemo_skills.evaluation.math_grader import extract_answer -from nemo_skills.utils import get_logger_name, unroll_files +from nemo_skills.utils import get_logger_name, nested_dataclass, unroll_files LOG = logging.getLogger(get_logger_name(__file__)) +@nested_dataclass(kw_only=True) +class MCQEvaluatorConfig: + extract_from_boxed: bool = True + # only used if extract_from_boxed is False + extract_regex: str = r"The final answer is (.+)$" + + def eval_mcq(cfg): + # Create config from cfg.eval_config (following pattern from other evaluators) + eval_config = MCQEvaluatorConfig(**cfg.eval_config) + def extract_letter(text, extract_from_boxed: bool = True, extract_regex: str = r"The final answer is (.+)$"): - # extract prediction from boxed{} - parsed = extract_answer(text, extract_from_boxed=extract_from_boxed, extract_regex=extract_regex) - if parsed is not None and len(parsed) != 1: - match = re.findall(r"\b[A-J]\b(?!.*\b[A-J]\b)", parsed, re.DOTALL) - if len(match) > 0: - parsed = match[-1].strip() + # extract prediction from boxed{} or regex + extracted_answer = extract_answer(text, extract_from_boxed=extract_from_boxed, extract_regex=extract_regex) + parsed_letter = None + + if extracted_answer is not None: + if len(extracted_answer) == 1: + parsed_letter = extracted_answer + elif len(extracted_answer) > 1: + # try to extract the letter from extracted answer, useful to match , {A}, *A*, etc. + match = re.findall(r"\b[A-Z]\b(?!.*\b[A-Z]\b)", extracted_answer, re.DOTALL) + if len(match) > 0: + parsed_letter = match[-1].strip() # adapted from https://artificialanalysis.ai/methodology/intelligence-benchmarking#intelligence-index-evaluation-suite-overview - if parsed is None: + if parsed_letter is None: match = re.findall(r"(?i)[\*\_]{0,2}Answer[\*\_]{0,2}\s*:[\s\*\_]{0,2}\s*([A-Z])(?![a-zA-Z0-9])", text) if match: - parsed = match[-1].strip() - return parsed + parsed_letter = match[-1].strip().upper() + + LOG.info( + f"Final parsed letter: {parsed_letter}, extract_from_boxed: {extract_from_boxed}, extract_regex: {extract_regex}, extracted_answer: {extracted_answer}" + ) + + return parsed_letter for file in unroll_files(cfg.input_files): with open(file, "rt", encoding="utf-8") as fin: data = [json.loads(line) for line in fin] with open(file, "wt", encoding="utf-8") as fout: for sample in tqdm(data): - extract_from_boxed = sample.get("extract_from_boxed", True) - extract_regex = sample.get("extract_regex", r"The final answer is (.+)$") + # Per-sample values override config defaults for backward compatibility + extract_from_boxed = sample.get("extract_from_boxed", eval_config.extract_from_boxed) + extract_regex = sample.get("extract_regex", eval_config.extract_regex) sample["predicted_answer"] = extract_letter( sample["generation"], extract_from_boxed=extract_from_boxed, extract_regex=extract_regex )