diff --git a/nemo_skills/dataset/mmmlu/__init__.py b/nemo_skills/dataset/mmmlu/__init__.py new file mode 100644 index 0000000000..7eb573db87 --- /dev/null +++ b/nemo_skills/dataset/mmmlu/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +METRICS_TYPE = "multichoice" +GENERATION_ARGS = "++prompt_config=generic/default ++eval_type=multichoice" diff --git a/nemo_skills/dataset/mmmlu/mmmlu_utils.py b/nemo_skills/dataset/mmmlu/mmmlu_utils.py new file mode 100644 index 0000000000..513040433f --- /dev/null +++ b/nemo_skills/dataset/mmmlu/mmmlu_utils.py @@ -0,0 +1,194 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import os +import urllib.request +from pathlib import Path + +import pandas + +SUPPORTED_LANGUAGES = [ + "AR-XY", # Arabic + "BN-BD", # Bengali + "DE-DE", # German + "ES-LA", # Spanish + "FR-FR", # French + "HI-IN", # Hindi + "ID-ID", # Indonesian + "IT-IT", # Italian + "JA-JP", # Japanese + "KO-KR", # Korean + "PT-BR", # Portuguese + "ZH-CN", # Chinese + "SW-KE", # Swahili + "YO-NG", # Yoruba +] + +subject2category = { + "abstract_algebra": "stem", + "anatomy": "other", + "astronomy": "stem", + "business_ethics": "other", + "clinical_knowledge": "other", + "college_biology": "stem", + "college_chemistry": "stem", + "college_computer_science": "stem", + "college_mathematics": "stem", + "college_medicine": "other", + "college_physics": "stem", + "computer_security": "stem", + "conceptual_physics": "stem", + "econometrics": "social_sciences", + "electrical_engineering": "stem", + "elementary_mathematics": "stem", + "formal_logic": "humanities", + "global_facts": "other", + "high_school_biology": "stem", + "high_school_chemistry": "stem", + "high_school_computer_science": "stem", + "high_school_european_history": "humanities", + "high_school_geography": "social_sciences", + "high_school_government_and_politics": "social_sciences", + "high_school_macroeconomics": "social_sciences", + "high_school_mathematics": "stem", + "high_school_microeconomics": "social_sciences", + "high_school_physics": "stem", + "high_school_psychology": "social_sciences", + "high_school_statistics": "stem", + "high_school_us_history": "humanities", + "high_school_world_history": "humanities", + "human_aging": "other", + "human_sexuality": "social_sciences", + "international_law": "humanities", + "jurisprudence": "humanities", + "logical_fallacies": "humanities", + "machine_learning": "stem", + "management": "other", + "marketing": "other", + "medical_genetics": "other", + "miscellaneous": "other", + "moral_disputes": "humanities", + "moral_scenarios": "humanities", + "nutrition": "other", 
+ "philosophy": "humanities", + "prehistory": "humanities", + "professional_accounting": "other", + "professional_law": "humanities", + "professional_medicine": "other", + "professional_psychology": "social_sciences", + "public_relations": "social_sciences", + "security_studies": "social_sciences", + "sociology": "social_sciences", + "us_foreign_policy": "social_sciences", + "virology": "other", + "world_religions": "humanities", +} + +QUERY_TEMPLATE_MULTICHOICE = """ +Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. + +{Question} + +A) {A} +B) {B} +C) {C} +D) {D} +""".strip() + +MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = "(?i){}[ \t]*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])" +# All the different ways "Answer" is written in different languages +MULTILINGUAL_ANSWER_REGEXES = [ + "Answer\s*:", + "Answer\s*:​​​​​​", # Korean invisible character + "উত্তর\s*:", + "उत्तर\s*:", + "উত্তরঃ", + "উত্তর\s*:", + "Antwort\s*:", + "답변\s*:", + "정답\s*:", + "답\s*:", + "答案\s*:", + "答案\s*:", + "答\s*:", + "答\s*:", + "答复\s*:", + "答曰\s*:", + "الإجابة:", + "الجواب:", + "إجابة:", + "الإجابة النهائية:", + "الإجابة الصحيحة:", + "الإجابة الصحيحة هي:", + "الإجابة هي:", + "الجواب النهائي:", + "Respuesta\s*:", + "Risposta\s*:", + "答え\s*:", + "答え\s*:", + "回答\s*:", + "回答\s*:", + "解答\s*:", + "Jawaban\s*:", + "Réponse\s*:", + "Resposta\s*:", + "Jibu\s*:", + "Idahun\s*:", + "Ìdáhùn\s*:", + "Idáhùn\s*:", + "Àmọ̀nà\s*:", + "Àdáhùn\s*:", + "Ànúgọ\s*:", + "Àṣàyàn\s*:", +] + + +class Schema: + ANSWER: str = "Answer" + QUESTION: str = "Question" + SUBJECT: str = "Subject" + OPTIONS: list[str] = ["A", "B", "C", "D"] + + +def download_mmmlu_datasets(languages: list[str]) -> dict[str, list[dict]]: + OPENAI_PUBLIC_URL = "https://openaipublic.blob.core.windows.net/simple-evals/{}" + data_dir = Path(__file__).absolute().parent + mmmlu_datasets = {} 
+    for language in languages:
+        suffix = "mmlu.csv" if language == "EN-US" else f"mmlu_{language}.csv"
+        download_dst_path = data_dir / suffix
+        if os.path.exists(download_dst_path):
+            print(f"Skipping download of {suffix} because it already exists")
+        else:
+            url = OPENAI_PUBLIC_URL.format(suffix)
+            urllib.request.urlretrieve(url, download_dst_path)
+            if not os.path.exists(download_dst_path):
+                raise RuntimeError(f"Failed to download {suffix}")
+
+        df = pandas.read_csv(download_dst_path, index_col=0)
+        examples = [row.to_dict() for _, row in df.iterrows()]
+        mmmlu_datasets[language] = examples
+    return mmmlu_datasets
+
+
+def format_multichoice_question(row):
+    return QUERY_TEMPLATE_MULTICHOICE.format(**row)
+
+
+def get_mcq_fields(entry: dict):
+    options_dict = {letter: entry[letter] for letter in Schema.OPTIONS}
+    options_text = "\n".join(f"{letter}) {option}" for letter, option in options_dict.items())
+    prompt = format_multichoice_question(entry)
+    return {"question": prompt, "options": options_text, **options_dict}
diff --git a/nemo_skills/dataset/mmmlu/prepare.py b/nemo_skills/dataset/mmmlu/prepare.py
new file mode 100644
index 0000000000..b770ec41b5
--- /dev/null
+++ b/nemo_skills/dataset/mmmlu/prepare.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+from pathlib import Path
+
+from nemo_skills.dataset.mmmlu.mmmlu_utils import (
+    MULTILINGUAL_ANSWER_PATTERN_TEMPLATE,
+    MULTILINGUAL_ANSWER_REGEXES,
+    SUPPORTED_LANGUAGES,
+    Schema,
+    download_mmmlu_datasets,
+    get_mcq_fields,
+    subject2category,
+)
+
+
+def format_entry(entry: dict, language: str) -> dict:
+    expected_answer = entry[Schema.ANSWER]
+    category = subject2category.get(entry[Schema.SUBJECT], "other")
+    regexes = [
+        MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex) for answer_regex in MULTILINGUAL_ANSWER_REGEXES
+    ]
+    LETTER_REGEX = r"\b\(?\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[Ａ]|[Ｂ]|[Ｃ]|[Ｄ])\s*\)?\.?\b"
+    GREEDY_REGEX = r"[\s\S]*" + LETTER_REGEX
+    regexes.append(GREEDY_REGEX)  # Matches the last A/B/C/D letter in the response
+    return {
+        "expected_answer": expected_answer,
+        "extract_from_boxed": False,
+        "extract_regex": regexes,
+        "subset_for_metrics": language,
+        "relaxed": False,
+        "category": category,
+        **get_mcq_fields(entry),
+    }
+
+
+def main(args):
+    languages = [lang for lang in args.languages if lang != "EN-US"]
+    valid_languages = set(SUPPORTED_LANGUAGES)
+    if args.include_english:
+        valid_languages.add("EN-US")
+        languages.append("EN-US")
+
+    invalid = set(languages) - valid_languages
+    if invalid:
+        raise ValueError(f"Unsupported languages: {invalid}")
+    datasets = download_mmmlu_datasets(languages)
+
+    data_dir = Path(__file__).absolute().parent
+    output_file = data_dir / "test.jsonl"
+    with open(output_file, "wt", encoding="utf-8") as fout:
+        for language, examples in datasets.items():
+            for entry in examples:
+                entry = format_entry(entry=entry, language=language)
+                json.dump(entry, fout, ensure_ascii=False)
+                fout.write("\n")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--languages",
+        default=SUPPORTED_LANGUAGES,
+        nargs="+",
+        help="Languages to process.",
+    )
+    parser.add_argument(
+        "--include_english",
+        action="store_true",
+        help="Include English split which corresponds to the original MMLU dataset.",
+    )
+    args = parser.parse_args()
+    main(args)
diff --git a/nemo_skills/evaluation/evaluator/mcq.py b/nemo_skills/evaluation/evaluator/mcq.py
index 821f1a47f8..b6076ca807 100644
--- a/nemo_skills/evaluation/evaluator/mcq.py
+++ b/nemo_skills/evaluation/evaluator/mcq.py
@@ -25,11 +25,32 @@
 LOG = logging.getLogger(get_logger_name(__file__))
 
 
+def normalize_extracted_answer(extracted_answer: str) -> str:
+    return (
+        # In arabic these are the letters used for A-D in multiple choice questions
+        extracted_answer.replace("أ", " A")
+        .replace("ب", " B")
+        .replace("ج", " C")
+        .replace("د", " D")
+        # In Bengali these are the letters used for A-D in multiple choice questions
+        .replace("অ", " A")
+        .replace("ব", " B")
+        .replace("ড", " C")
+        .replace("ঢ", " D")
+        # In Japanese these are the letters sometimes used for A-D in multiple choice questions
+        .replace("Ａ", " A")
+        .replace("Ｂ", " B")
+        .replace("Ｃ", " C")
+        .replace("Ｄ", " D")
+        .strip()
+    )
+
+
 @nested_dataclass(kw_only=True)
 class MCQEvaluatorConfig(BaseEvaluatorConfig):
     extract_from_boxed: bool = True
     # only used if extract_from_boxed is False
-    extract_regex: str = r"The final answer is (.+)$"
+    extract_regex: str | list[str] = r"The final answer is (.+)$"
     # if relaxed is True:
     # extract from regex FIRST, if not found, extract from boxed
     # if relaxed is False:
@@ -42,22 +63,37 @@ def eval_mcq(cfg):
     eval_config = MCQEvaluatorConfig(**cfg)
 
     def extract_letter(
-        text, extract_from_boxed: bool = True, extract_regex: str = r"The final answer is (.+)$", relaxed=False
+        text,
+        extract_from_boxed: bool = True,
+        extract_regex: str | list[str] = r"The final answer is (.+)$",
+        relaxed=False,
     ):
         # extract prediction from boxed{} or regex
-        extracted_answer = extract_answer(
-            text, extract_from_boxed=extract_from_boxed, extract_regex=extract_regex, relaxed=relaxed
-        )
-        parsed_letter = None
+        extracted_answer = None
+        if isinstance(extract_regex, list):
+            for regex in extract_regex:
+                extracted_answer = extract_answer(
+                    text, extract_from_boxed=extract_from_boxed, extract_regex=regex, relaxed=relaxed
+                )
+                if extracted_answer is not None:
+                    break
+        else:
+            extracted_answer = extract_answer(
+                text, extract_from_boxed=extract_from_boxed, extract_regex=extract_regex, relaxed=relaxed
+            )
+
+        if extracted_answer is not None:
+            extracted_answer = normalize_extracted_answer(extracted_answer)
+        parsed_letter = None
         if extracted_answer is not None:
             if len(extracted_answer) == 1:
-                parsed_letter = extracted_answer
+                parsed_letter = extracted_answer.upper()
             elif len(extracted_answer) > 1:
                 # try to extract the letter from extracted answer, useful to match , {A}, *A*, etc.
                 match = re.findall(r"\b[A-Z]\b(?!.*\b[A-Z]\b)", extracted_answer, re.DOTALL)
                 if len(match) > 0:
-                    parsed_letter = match[-1].strip()
+                    parsed_letter = match[-1].strip().upper()
 
         # adapted from https://artificialanalysis.ai/methodology/intelligence-benchmarking#intelligence-index-evaluation-suite-overview
         if parsed_letter is None: