Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions nemo_skills/dataset/mmmlu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Metric computation type the evaluation harness should use for this dataset.
METRICS_TYPE = "multichoice"
# Default generation settings: generic prompt config plus multichoice answer evaluation.
GENERATION_ARGS = "++prompt_config=generic/default ++eval_type=multichoice"
194 changes: 194 additions & 0 deletions nemo_skills/dataset/mmmlu/mmmlu_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import urllib.request
from pathlib import Path

import pandas

# Locale-style codes of the MMMLU translations available for download
# (see download_mmmlu_datasets). "EN-US" (the original English MMLU) is
# handled separately and is not listed here.
SUPPORTED_LANGUAGES = [
    "AR-XY",  # Arabic
    "BN-BD",  # Bengali
    "DE-DE",  # German
    "ES-LA",  # Spanish
    "FR-FR",  # French
    "HI-IN",  # Hindi
    "ID-ID",  # Indonesian
    "IT-IT",  # Italian
    "JA-JP",  # Japanese
    "KO-KR",  # Korean
    "PT-BR",  # Portuguese
    "ZH-CN",  # Chinese
    "SW-KE",  # Swahili
    "YO-NG",  # Yoruba
]

# Maps every MMLU subject name to one of four top-level categories
# (stem / humanities / social_sciences / other), attached to each prepared
# example as its "category" field.
subject2category = {
    "abstract_algebra": "stem",
    "anatomy": "other",
    "astronomy": "stem",
    "business_ethics": "other",
    "clinical_knowledge": "other",
    "college_biology": "stem",
    "college_chemistry": "stem",
    "college_computer_science": "stem",
    "college_mathematics": "stem",
    "college_medicine": "other",
    "college_physics": "stem",
    "computer_security": "stem",
    "conceptual_physics": "stem",
    "econometrics": "social_sciences",
    "electrical_engineering": "stem",
    "elementary_mathematics": "stem",
    "formal_logic": "humanities",
    "global_facts": "other",
    "high_school_biology": "stem",
    "high_school_chemistry": "stem",
    "high_school_computer_science": "stem",
    "high_school_european_history": "humanities",
    "high_school_geography": "social_sciences",
    "high_school_government_and_politics": "social_sciences",
    "high_school_macroeconomics": "social_sciences",
    "high_school_mathematics": "stem",
    "high_school_microeconomics": "social_sciences",
    "high_school_physics": "stem",
    "high_school_psychology": "social_sciences",
    "high_school_statistics": "stem",
    "high_school_us_history": "humanities",
    "high_school_world_history": "humanities",
    "human_aging": "other",
    "human_sexuality": "social_sciences",
    "international_law": "humanities",
    "jurisprudence": "humanities",
    "logical_fallacies": "humanities",
    "machine_learning": "stem",
    "management": "other",
    "marketing": "other",
    "medical_genetics": "other",
    "miscellaneous": "other",
    "moral_disputes": "humanities",
    "moral_scenarios": "humanities",
    "nutrition": "other",
    "philosophy": "humanities",
    "prehistory": "humanities",
    "professional_accounting": "other",
    "professional_law": "humanities",
    "professional_medicine": "other",
    "professional_psychology": "social_sciences",
    "public_relations": "social_sciences",
    "security_studies": "social_sciences",
    "sociology": "social_sciences",
    "us_foreign_policy": "social_sciences",
    "virology": "other",
    "world_religions": "humanities",
}

# Prompt template for one 4-option multiple-choice question. Filled via
# str.format with the keys "Question", "A", "B", "C" and "D"; instructs the
# model to finish with a final "Answer: $LETTER" line for extraction.
QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.

{Question}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()

# Template for extracting the chosen option letter after a localized "Answer:"
# marker ({} is substituted with the marker regex). The character class lists
# the explicit option letters per script: Latin A-D, the four Arabic option
# letters [أبجد] (a range such as [أ-د] would also match Arabic letters that
# are not answer options), the Bengali option letters, and fullwidth Ａ-Ｄ
# sometimes produced in Japanese responses.
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = "(?i){}[ \t]*([A-D]|[أبجد]|[অ]|[ব]|[ড]|[ঢ]|[Ａ]|[Ｂ]|[Ｃ]|[Ｄ])"
# All the different ways "Answer" is written in different languages.
# Raw strings are used so "\s" is a regex escape rather than an invalid Python
# string escape (a SyntaxWarning on Python 3.12+). CJK entries appear twice to
# cover both the fullwidth colon "：" and the ASCII colon ":".
MULTILINGUAL_ANSWER_REGEXES = [
    r"Answer\s*:",
    "Answer\\s*:\u200b\u200b\u200b\u200b\u200b\u200b",  # "Answer:" followed by Korean zero-width characters
    r"উত্তর\s*:",
    r"उत्तर\s*:",
    r"উত্তরঃ",
    r"উত্তর\s*:",
    r"Antwort\s*:",
    r"답변\s*:",
    r"정답\s*:",
    r"답\s*:",
    r"答案\s*：",
    r"答案\s*:",
    r"答\s*：",
    r"答\s*:",
    r"答复\s*：",
    r"答曰\s*：",
    r"الإجابة:",
    r"الجواب:",
    r"إجابة:",
    r"الإجابة النهائية:",
    r"الإجابة الصحيحة:",
    r"الإجابة الصحيحة هي:",
    r"الإجابة هي:",
    r"الجواب النهائي:",
    r"Respuesta\s*:",
    r"Risposta\s*:",
    r"答え\s*：",
    r"答え\s*:",
    r"回答\s*：",
    r"回答\s*:",
    r"解答\s*：",
    r"Jawaban\s*:",
    r"Réponse\s*:",
    r"Resposta\s*:",
    r"Jibu\s*:",
    r"Idahun\s*:",
    r"Ìdáhùn\s*:",
    r"Idáhùn\s*:",
    r"Àmọ̀nà\s*:",
    r"Àdáhùn\s*:",
    r"Ànúgọ\s*:",
    r"Àṣàyàn\s*:",
]


class Schema:
    """Column names used by the downloaded MMMLU CSV files."""

    # Column holding the correct option letter.
    ANSWER: str = "Answer"
    # Column holding the question text.
    QUESTION: str = "Question"
    # Column holding the MMLU subject name (key into subject2category).
    SUBJECT: str = "Subject"
    # Columns holding the four answer option texts.
    OPTIONS: list[str] = ["A", "B", "C", "D"]


def download_mmmlu_datasets(languages: list[str]) -> dict[str, list[dict]]:
    """Download the MMMLU CSV for each requested language and parse its rows.

    Files are fetched from OpenAI's public simple-evals storage and cached next
    to this module; an already-downloaded file is reused instead of re-fetched.

    Args:
        languages: language codes (see SUPPORTED_LANGUAGES); "EN-US" maps to
            the original English MMLU file.

    Returns:
        Mapping from language code to a list of row dicts from its CSV.

    Raises:
        RuntimeError: if a download completed but the file is still missing.
    """
    OPENAI_PUBLIC_URL = "https://openaipublic.blob.core.windows.net/simple-evals/{}"
    data_dir = Path(__file__).absolute().parent
    mmmlu_datasets = {}
    for language in languages:
        # The English split is the original MMLU and uses a different file name.
        suffix = "mmlu.csv" if language == "EN-US" else f"mmlu_{language}.csv"
        download_dst_path = data_dir / suffix
        if download_dst_path.exists():
            print(f"Skipping download of {suffix} because it already exists")
        else:
            url = OPENAI_PUBLIC_URL.format(suffix)
            urllib.request.urlretrieve(url, download_dst_path)
            if not download_dst_path.exists():
                raise RuntimeError(f"Failed to download {suffix}")

        # First CSV column is a row index; the rest follow the Schema columns.
        df = pandas.read_csv(download_dst_path, index_col=0)
        examples = [row.to_dict() for _, row in df.iterrows()]
        mmmlu_datasets[language] = examples
    return mmmlu_datasets


def format_multichoice_question(row):
    """Render the multiple-choice prompt template for one entry.

    ``row`` must provide the "Question", "A", "B", "C" and "D" keys used by
    QUERY_TEMPLATE_MULTICHOICE.
    """
    fields = dict(row)
    return QUERY_TEMPLATE_MULTICHOICE.format(**fields)


def get_mcq_fields(entry: dict):
    """Build the question/options fields for one multiple-choice entry.

    Returns a dict with the rendered prompt under "question", a newline-joined
    "A) ..." listing under "options", and each raw option text under its
    letter key.
    """
    choices = {}
    rendered_lines = []
    for letter in Schema.OPTIONS:
        choices[letter] = entry[letter]
        rendered_lines.append(f"{letter}) {entry[letter]}")
    result = {
        "question": format_multichoice_question(entry),
        "options": "\n".join(rendered_lines),
    }
    result.update(choices)
    return result
86 changes: 86 additions & 0 deletions nemo_skills/dataset/mmmlu/prepare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
from pathlib import Path

from nemo_skills.dataset.mmmlu.mmmlu_utils import (
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE,
MULTILINGUAL_ANSWER_REGEXES,
SUPPORTED_LANGUAGES,
Schema,
download_mmmlu_datasets,
get_mcq_fields,
subject2category,
)


def format_entry(entry: dict, language: str) -> dict:
    """Convert one raw MMMLU CSV row into the benchmark's jsonl entry format.

    Args:
        entry: row dict with the Schema columns (Question, A-D, Answer, Subject).
        language: language code of the split the row came from; stored as the
            per-language metrics subset.

    Returns:
        Dict with the expected answer, answer-extraction regexes, subject
        category and the rendered question/options fields.

    Raises:
        KeyError: if the row's subject is missing from subject2category.
    """
    expected_answer = entry[Schema.ANSWER]
    # Direct access so an unknown subject fails loudly instead of being
    # silently bucketed into "other" and corrupting per-category metrics.
    category = subject2category[entry[Schema.SUBJECT]]
    regexes = [
        MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex) for answer_regex in MULTILINGUAL_ANSWER_REGEXES
    ]
    # Explicit option letters per script: the Arabic range [أ-د] would also
    # match Arabic letters that are not answer options, so [أبجد] is listed
    # explicitly. Letters may be wrapped in parentheses or followed by a period.
    LETTER_REGEX = r"\b\(?\s*([A-D]|[أبجد]|[অ]|[ব]|[ড]|[ঢ]|[Ａ]|[Ｂ]|[Ｃ]|[Ｄ])\s*\)?\.?\b"
    GREEDY_REGEX = r"[\s\S]*" + LETTER_REGEX
    regexes.append(GREEDY_REGEX)  # Matches the last option letter in the response
    return {
        "expected_answer": expected_answer,
        "extract_from_boxed": False,
        "extract_regex": regexes,
        "subset_for_metrics": language,
        "relaxed": False,
        "category": category,
        **get_mcq_fields(entry),
    }


def main(args):
    """Download the requested MMMLU splits and write them all to test.jsonl.

    Raises:
        ValueError: if "EN-US" is requested without --include_english, or if
            any requested language is unsupported.
    """
    # "EN-US" is only valid together with --include_english; reject it loudly
    # instead of silently dropping a user-specified language.
    if "EN-US" in args.languages and not args.include_english:
        raise ValueError("EN-US was requested but --include_english is not set.")

    languages = [lang for lang in args.languages if lang != "EN-US"]
    valid_languages = set(SUPPORTED_LANGUAGES)
    if args.include_english:
        valid_languages.add("EN-US")
        languages.append("EN-US")

    invalid = set(languages) - valid_languages
    if invalid:
        raise ValueError(f"Unsupported languages: {invalid}")
    datasets = download_mmmlu_datasets(languages)

    # Format everything first so a mid-run failure cannot leave a truncated
    # test.jsonl behind: the output file is only opened once all entries exist.
    formatted_entries = []
    for language, examples in datasets.items():
        for entry in examples:
            formatted_entries.append(format_entry(entry=entry, language=language))

    data_dir = Path(__file__).absolute().parent
    output_file = data_dir / "test.jsonl"
    with open(output_file, "wt", encoding="utf-8") as fout:
        for entry in formatted_entries:
            json.dump(entry, fout, ensure_ascii=False)
            fout.write("\n")


# CLI entry point: prepares the MMMLU test.jsonl for the requested languages.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--languages",
        default=SUPPORTED_LANGUAGES,
        nargs="+",
        help="Languages to process.",
    )
    parser.add_argument(
        "--include_english",
        action="store_true",
        help="Include English split which corresponds to the original MMLU dataset.",
    )
    args = parser.parse_args()
    main(args)
52 changes: 44 additions & 8 deletions nemo_skills/evaluation/evaluator/mcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,32 @@
LOG = logging.getLogger(get_logger_name(__file__))


# Localized option letters mapped to their Latin A-D equivalents. Each target
# keeps the leading space the original chained .replace() calls inserted, so
# downstream letter extraction still sees a word boundary.
_OPTION_LETTER_TRANSLATION = str.maketrans(
    {
        # In arabic these are the letters used for A-D in multiple choice questions
        "أ": " A",
        "ب": " B",
        "ج": " C",
        "د": " D",
        # In Bengali these are the letters used for A-D in multiple choice questions
        "অ": " A",
        "ব": " B",
        "ড": " C",
        "ঢ": " D",
        # In Japanese these are the letters sometimes used for A-D in multiple choice questions
        "Ａ": " A",
        "Ｂ": " B",
        "Ｃ": " C",
        "Ｄ": " D",
    }
)


def normalize_extracted_answer(extracted_answer: str) -> str:
    """Map Arabic/Bengali/fullwidth option letters to Latin A-D.

    Every occurrence anywhere in the string is replaced with a space-prefixed
    Latin letter (one C-level pass via str.translate), then the result is
    stripped of surrounding whitespace.
    """
    return extracted_answer.translate(_OPTION_LETTER_TRANSLATION).strip()
Comment on lines +28 to +46
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Avoid global character replacement across the full extracted text.

This replaces Arabic/Bengali letters everywhere in extracted_answer, so longer answers can gain synthetic A/B/C/D tokens and be misgraded.

💡 Proposed fix
+OPTION_CHAR_MAP = {
+    "A": "A",
+    "B": "B",
+    "C": "C",
+    "D": "D",
+    "أ": "A",
+    "ب": "B",
+    "ج": "C",
+    "د": "D",
+    "অ": "A",
+    "ব": "B",
+    "ড": "C",
+    "ঢ": "D",
+    "A": "A",
+    "B": "B",
+    "C": "C",
+    "D": "D",
+}
+
 def normalize_extracted_answer(extracted_answer: str) -> str:
-    return (
-        # In arabic these are the letters used for A-D in multiple choice questions
-        extracted_answer.replace("أ", " A")
-        .replace("ب", " B")
-        .replace("ج", " C")
-        .replace("د", " D")
-        # In Bengali these are the letters used for A-D in multiple choice questions
-        .replace("অ", " A")
-        .replace("ব", " B")
-        .replace("ড", " C")
-        .replace("ঢ", " D")
-        # In Japanese these are the letters sometimes used for A-D in multiple choice questions
-        .replace("A", " A")
-        .replace("B", " B")
-        .replace("C", " C")
-        .replace("D", " D")
-        .strip()
-    )
+    normalized = extracted_answer.strip()
+    match = re.fullmatch(r"[\s\(\[\{<\*_]*([A-DأبجدঅবডঢABCD])[\s\)\]\}>.\*_]*", normalized)
+    if not match:
+        return normalized
+    return OPTION_CHAR_MAP[match.group(1)]
🧰 Tools
🪛 Ruff (0.15.2)

[warning] 41-41: String contains ambiguous (FULLWIDTH LATIN CAPITAL LETTER A). Did you mean A (LATIN CAPITAL LETTER A)?

(RUF001)


[warning] 42-42: String contains ambiguous (FULLWIDTH LATIN CAPITAL LETTER B). Did you mean B (LATIN CAPITAL LETTER B)?

(RUF001)


[warning] 43-43: String contains ambiguous (FULLWIDTH LATIN CAPITAL LETTER C). Did you mean C (LATIN CAPITAL LETTER C)?

(RUF001)


[warning] 44-44: String contains ambiguous (FULLWIDTH LATIN CAPITAL LETTER D). Did you mean D (LATIN CAPITAL LETTER D)?

(RUF001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_skills/evaluation/evaluator/mcq.py` around lines 28 - 46, The current
normalize_extracted_answer function performs blind global .replace() calls on
extracted_answer which can turn characters inside longer answers into synthetic
" A/B/C/D" tokens; change it to only map standalone choice markers or markers at
the start/end of the answer (e.g., when the entire string equals an
Arabic/Bengali/Japanese marker, when the marker is surrounded by
whitespace/punctuation, or when it appears as the first token) — update
normalize_extracted_answer to use targeted matching (regex or token-based
checks) on extracted_answer so only isolated choice markers are converted to
"A"/"B"/"C"/"D" and embedded characters in longer answers are left unchanged.



@nested_dataclass(kw_only=True)
class MCQEvaluatorConfig(BaseEvaluatorConfig):
extract_from_boxed: bool = True
# only used if extract_from_boxed is False
extract_regex: str = r"The final answer is (.+)$"
extract_regex: str | list[str] = r"The final answer is (.+)$"
# if relaxed is True:
# extract from regex FIRST, if not found, extract from boxed
# if relaxed is False:
Expand All @@ -42,22 +63,37 @@ def eval_mcq(cfg):
eval_config = MCQEvaluatorConfig(**cfg)

def extract_letter(
text, extract_from_boxed: bool = True, extract_regex: str = r"The final answer is (.+)$", relaxed=False
text,
extract_from_boxed: bool = True,
extract_regex: str | list[str] = r"The final answer is (.+)$",
relaxed=False,
):
# extract prediction from boxed{} or regex
extracted_answer = extract_answer(
text, extract_from_boxed=extract_from_boxed, extract_regex=extract_regex, relaxed=relaxed
)
parsed_letter = None
extracted_answer = None
if isinstance(extract_regex, list):
for regex in extract_regex:
extracted_answer = extract_answer(
text, extract_from_boxed=extract_from_boxed, extract_regex=regex, relaxed=relaxed
)
if extracted_answer is not None:
break
else:
extracted_answer = extract_answer(
text, extract_from_boxed=extract_from_boxed, extract_regex=extract_regex, relaxed=relaxed
)

if extracted_answer is not None:
extracted_answer = normalize_extracted_answer(extracted_answer)

parsed_letter = None
if extracted_answer is not None:
if len(extracted_answer) == 1:
parsed_letter = extracted_answer
parsed_letter = extracted_answer.upper()
elif len(extracted_answer) > 1:
# try to extract the letter from extracted answer, useful to match <A>, {A}, *A*, etc.
match = re.findall(r"\b[A-Z]\b(?!.*\b[A-Z]\b)", extracted_answer, re.DOTALL)
if len(match) > 0:
parsed_letter = match[-1].strip()
parsed_letter = match[-1].strip().upper()

# adapted from https://artificialanalysis.ai/methodology/intelligence-benchmarking#intelligence-index-evaluation-suite-overview
if parsed_letter is None:
Expand Down