From 3c366ddfab9fa9f99327da84c1ebab9b04d6a82b Mon Sep 17 00:00:00 2001 From: Haonan Li Date: Tue, 25 Jul 2023 06:14:27 +0400 Subject: [PATCH] [Feature] Add CMMLU dataset (#91) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add CMMLU * debug cmmlu * add slurm args `qos` * fix format: space before comment * remove unused variable * change the location of `answer is` --------- Co-authored-by: 李浩楠 Co-authored-by: 李浩楠 Co-authored-by: Leymore --- configs/datasets/cmmlu/cmmlu_gen.py | 4 + configs/datasets/cmmlu/cmmlu_gen_ffe7c0.py | 122 +++++++++++++++++++++ configs/datasets/cmmlu/cmmlu_ppl.py | 4 + configs/datasets/cmmlu/cmmlu_ppl_fd1f2f.py | 122 +++++++++++++++++++++ opencompass/datasets/__init__.py | 1 + opencompass/datasets/cmmlu.py | 34 ++++++ opencompass/runners/slurm.py | 5 + run.py | 6 + 8 files changed, 298 insertions(+) create mode 100644 configs/datasets/cmmlu/cmmlu_gen.py create mode 100644 configs/datasets/cmmlu/cmmlu_gen_ffe7c0.py create mode 100644 configs/datasets/cmmlu/cmmlu_ppl.py create mode 100644 configs/datasets/cmmlu/cmmlu_ppl_fd1f2f.py create mode 100644 opencompass/datasets/cmmlu.py diff --git a/configs/datasets/cmmlu/cmmlu_gen.py b/configs/datasets/cmmlu/cmmlu_gen.py new file mode 100644 index 000000000..0245f871f --- /dev/null +++ b/configs/datasets/cmmlu/cmmlu_gen.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .cmmlu_gen_ffe7c0 import cmmlu_datasets # noqa: F401, F403 diff --git a/configs/datasets/cmmlu/cmmlu_gen_ffe7c0.py b/configs/datasets/cmmlu/cmmlu_gen_ffe7c0.py new file mode 100644 index 000000000..e4b0bc4f6 --- /dev/null +++ b/configs/datasets/cmmlu/cmmlu_gen_ffe7c0.py @@ -0,0 +1,122 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import FixKRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.openicl.icl_evaluator import AccEvaluator +from opencompass.datasets import CMMLUDataset +from opencompass.utils.text_postprocessors import first_capital_postprocess + +cmmlu_subject_mapping = { + 'agronomy': '农学', + 'anatomy': '解剖学', + 'ancient_chinese': '古汉语', + 'arts': '艺术学', + 'astronomy': '天文学', + 'business_ethics': '商业伦理', + 'chinese_civil_service_exam': '中国公务员考试', + 'chinese_driving_rule': '中国驾驶规则', + 'chinese_food_culture': '中国饮食文化', + 'chinese_foreign_policy': '中国外交政策', + 'chinese_history': '中国历史', + 'chinese_literature': '中国文学', + 'chinese_teacher_qualification': '中国教师资格', + 'clinical_knowledge': '临床知识', + 'college_actuarial_science': '大学精算学', + 'college_education': '大学教育学', + 'college_engineering_hydrology': '大学工程水文学', + 'college_law': '大学法律', + 'college_mathematics': '大学数学', + 'college_medical_statistics': '大学医学统计', + 'college_medicine': '大学医学', + 'computer_science': '计算机科学', + 'computer_security': '计算机安全', + 'conceptual_physics': '概念物理学', + 'construction_project_management': '建设工程管理', + 'economics': '经济学', + 'education': '教育学', + 'electrical_engineering': '电气工程', + 'elementary_chinese': '小学语文', + 'elementary_commonsense': '小学常识', + 'elementary_information_and_technology': '小学信息技术', + 'elementary_mathematics': '初等数学', + 'ethnology': '民族学', + 'food_science': '食品科学', + 'genetics': '遗传学', + 'global_facts': '全球事实', + 'high_school_biology': '高中生物', + 'high_school_chemistry': '高中化学', + 'high_school_geography': '高中地理', + 'high_school_mathematics': '高中数学', + 'high_school_physics': '高中物理学', + 'high_school_politics': '高中政治', + 'human_sexuality': '人类性行为', + 'international_law': '国际法学', + 'journalism': '新闻学', + 'jurisprudence': '法理学', + 'legal_and_moral_basis': '法律与道德基础', + 'logical': '逻辑学', + 'machine_learning': '机器学习', + 'management': '管理学', + 'marketing': '市场营销', + 'marxist_theory': '马克思主义理论', + 'modern_chinese': '现代汉语', + 'nutrition': '营养学', + 'philosophy': '哲学', + 'professional_accounting': '专业会计', + 'professional_law': '专业法学', + 'professional_medicine': '专业医学', + 'professional_psychology': '专业心理学', + 'public_relations': '公共关系', + 'security_study': '安全研究', + 'sociology': '社会学', + 'sports_science': '体育学', + 'traditional_chinese_medicine': '中医中药', + 'virology': '病毒学', + 'world_history': '世界历史', + 'world_religions': '世界宗教' +} + + +cmmlu_all_sets = list(cmmlu_subject_mapping.keys()) + +cmmlu_datasets = [] +for _name in cmmlu_all_sets: + _ch_name = cmmlu_subject_mapping[_name] + cmmlu_infer_cfg = dict( + ice_template=dict( + type=PromptTemplate, + template=dict( + begin="", + round=[ + dict( + role="HUMAN", + prompt= + f"以下是关于{_ch_name}的单项选择题,请直接给出正确答案的选项。\n题目:{{question}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}" + ), + dict(role="BOT", prompt='答案是: {answer}'), + ]), + ice_token="", + ), + retriever=dict(type=FixKRetriever), + inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]), + ) + + cmmlu_eval_cfg = dict( + evaluator=dict(type=AccEvaluator), + pred_postprocessor=dict(type=first_capital_postprocess)) + + cmmlu_datasets.append( + dict( + type=CMMLUDataset, + path="./data/cmmlu/", + name=_name, + abbr=f"cmmlu-{_name}", + reader_cfg=dict( + input_columns=["question", "A", "B", "C", "D"], + output_column="answer", + train_split="dev", + test_split='test'), + infer_cfg=cmmlu_infer_cfg, + eval_cfg=cmmlu_eval_cfg, + )) + +del _name, _ch_name diff --git a/configs/datasets/cmmlu/cmmlu_ppl.py b/configs/datasets/cmmlu/cmmlu_ppl.py new file mode 100644 index 000000000..645494f8e --- /dev/null +++ b/configs/datasets/cmmlu/cmmlu_ppl.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .cmmlu_ppl_fd1f2f import cmmlu_datasets # noqa: F401, F403 diff --git a/configs/datasets/cmmlu/cmmlu_ppl_fd1f2f.py b/configs/datasets/cmmlu/cmmlu_ppl_fd1f2f.py new file mode 100644 index 000000000..eb9ea96b9 --- /dev/null +++ b/configs/datasets/cmmlu/cmmlu_ppl_fd1f2f.py @@ -0,0 +1,122 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import FixKRetriever +from opencompass.openicl.icl_inferencer import PPLInferencer +from opencompass.openicl.icl_evaluator import AccEvaluator +from opencompass.datasets import CMMLUDataset +from opencompass.utils.text_postprocessors import first_capital_postprocess + +cmmlu_subject_mapping = { + 'agronomy': '农学', + 'anatomy': '解剖学', + 'ancient_chinese': '古汉语', + 'arts': '艺术学', + 'astronomy': '天文学', + 'business_ethics': '商业伦理', + 'chinese_civil_service_exam': '中国公务员考试', + 'chinese_driving_rule': '中国驾驶规则', + 'chinese_food_culture': '中国饮食文化', + 'chinese_foreign_policy': '中国外交政策', + 'chinese_history': '中国历史', + 'chinese_literature': '中国文学', + 'chinese_teacher_qualification': '中国教师资格', + 'clinical_knowledge': '临床知识', + 'college_actuarial_science': '大学精算学', + 'college_education': '大学教育学', + 'college_engineering_hydrology': '大学工程水文学', + 'college_law': '大学法律', + 'college_mathematics': '大学数学', + 'college_medical_statistics': '大学医学统计', + 'college_medicine': '大学医学', + 'computer_science': '计算机科学', + 'computer_security': '计算机安全', + 'conceptual_physics': '概念物理学', + 'construction_project_management': '建设工程管理', + 'economics': '经济学', + 'education': '教育学', + 'electrical_engineering': '电气工程', + 'elementary_chinese': '小学语文', + 'elementary_commonsense': '小学常识', + 'elementary_information_and_technology': '小学信息技术', + 'elementary_mathematics': '初等数学', + 'ethnology': '民族学', + 'food_science': '食品科学', + 'genetics': '遗传学', + 'global_facts': '全球事实', + 'high_school_biology': '高中生物', + 'high_school_chemistry': '高中化学', + 'high_school_geography': '高中地理', + 'high_school_mathematics': '高中数学', + 'high_school_physics': '高中物理学', + 'high_school_politics': '高中政治', + 'human_sexuality': '人类性行为', + 'international_law': '国际法学', + 'journalism': '新闻学', + 'jurisprudence': '法理学', + 'legal_and_moral_basis': '法律与道德基础', + 'logical': '逻辑学', + 'machine_learning': '机器学习', + 'management': '管理学', + 'marketing': '市场营销', + 'marxist_theory': '马克思主义理论', + 'modern_chinese': '现代汉语', + 'nutrition': '营养学', + 'philosophy': '哲学', + 'professional_accounting': '专业会计', + 'professional_law': '专业法学', + 'professional_medicine': '专业医学', + 'professional_psychology': '专业心理学', + 'public_relations': '公共关系', + 'security_study': '安全研究', + 'sociology': '社会学', + 'sports_science': '体育学', + 'traditional_chinese_medicine': '中医中药', + 'virology': '病毒学', + 'world_history': '世界历史', + 'world_religions': '世界宗教' +} + + +cmmlu_all_sets = list(cmmlu_subject_mapping.keys()) + +cmmlu_datasets = [] +for _name in cmmlu_all_sets: + _ch_name = cmmlu_subject_mapping[_name] + cmmlu_infer_cfg = dict( + ice_template=dict( + type=PromptTemplate, + template={ + answer: dict( + begin="", + round=[ + dict( + role="HUMAN", + prompt=f"以下是关于{_ch_name}的单项选择题,请直接给出正确答案的选项。\n题目:{{question}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}" + ), + dict(role="BOT", prompt=f'答案是: {answer}'), + ]) + for answer in ["A", "B", "C", "D"] + }, + ice_token="", + ), + retriever=dict(type=FixKRetriever), + inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]), + ) + + cmmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator)) + + cmmlu_datasets.append( + dict( + type=CMMLUDataset, + path="./data/cmmlu/", + name=_name, + abbr=f"cmmlu-{_name}", + reader_cfg=dict( + input_columns=["question", "A", "B", "C", "D"], + output_column="answer", + train_split="dev", + test_split='test'), + infer_cfg=cmmlu_infer_cfg, + eval_cfg=cmmlu_eval_cfg, + )) + +del _name, _ch_name diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py index 08616258f..2124fe0d9 100644 --- a/opencompass/datasets/__init__.py +++ b/opencompass/datasets/__init__.py @@ -11,6 +11,7 @@ from .chid import * # noqa: F401, F403 from .civilcomments import * # noqa: F401, F403 from .cluewsc import * # noqa: F401, F403 +from .cmmlu import * # noqa: F401, F403 from .cmnli import * # noqa: F401, F403 from .cmrc import * # noqa: F401, F403 from .commonsenseqa import * # noqa: F401, F403 diff --git a/opencompass/datasets/cmmlu.py b/opencompass/datasets/cmmlu.py new file mode 100644 index 000000000..634cc929e --- /dev/null +++ b/opencompass/datasets/cmmlu.py @@ -0,0 +1,34 @@ +import csv +import os.path as osp + +from datasets import Dataset, DatasetDict + +from opencompass.registry import LOAD_DATASET + +from .base import BaseDataset + + +@LOAD_DATASET.register_module() +class CMMLUDataset(BaseDataset): + + @staticmethod + def load(path: str, name: str): + dataset = DatasetDict() + for split in ['dev', 'test']: + raw_data = [] + filename = osp.join(path, split, f'{name}.csv') + with open(filename, encoding='utf-8') as f: + reader = csv.reader(f) + _ = next(reader) # skip the header + for row in reader: + assert len(row) == 7 + raw_data.append({ + 'question': row[1], + 'A': row[2], + 'B': row[3], + 'C': row[4], + 'D': row[5], + 'answer': row[6], + }) + dataset[split] = Dataset.from_list(raw_data) + return dataset diff --git a/opencompass/runners/slurm.py b/opencompass/runners/slurm.py index 646adcb86..ddb7808d7 100644 --- a/opencompass/runners/slurm.py +++ b/opencompass/runners/slurm.py @@ -28,6 +28,7 @@ class SlurmRunner(BaseRunner): retry (int): Number of retries if the job failed. Defaults to 2. partition (str): Slurm partition name. Defaults to None. quotatype (str): Slurm quota type. Defaults to None. + qos (str): Slurm quality of service. Defaults to None. debug (bool): Whether to run in debug mode. Defaults to False. lark_bot_url (str): Lark bot url. Defaults to None. """ @@ -38,6 +39,7 @@ def __init__(self, retry: int = 2, partition: str = None, quotatype: str = None, + qos: str = None, debug: bool = False, lark_bot_url: str = None): super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url) @@ -45,6 +47,7 @@ def __init__(self, self.retry = retry self.partition = partition self.quotatype = quotatype + self.qos = qos def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]: """Launch multiple tasks. @@ -97,6 +100,8 @@ def _launch(self, task_cfg: ConfigDict, random_sleep: bool = True): tmpl += f' -p {self.partition}' if self.quotatype: tmpl += f' --quotatype={self.quotatype}' + if self.qos: + tmpl += f' --qos={self.qos}' if num_gpus > 0: tmpl += f' --gres=gpu:{num_gpus}' tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}' diff --git a/run.py b/run.py index aaa90ac19..661beed2e 100644 --- a/run.py +++ b/run.py @@ -129,6 +129,10 @@ def parse_slurm_args(slurm_parser): help='Slurm quota type', default=None, type=str) + slurm_parser.add_argument('--qos', + help='Slurm quality of service', + default=None, + type=str) def parse_dlc_args(dlc_parser): @@ -286,6 +290,7 @@ def exec_infer_runner(tasks, args, cfg): max_num_workers=args.max_num_workers, partition=args.partition, quotatype=args.quotatype, + qos=args.qos, retry=args.retry, debug=args.debug, lark_bot_url=cfg['lark_bot_url']) @@ -311,6 +316,7 @@ def exec_eval_runner(tasks, args, cfg): max_num_workers=args.max_num_workers, partition=args.partition, quotatype=args.quotatype, + qos=args.qos, retry=args.retry, debug=args.debug, lark_bot_url=cfg['lark_bot_url'])