Skip to content

Commit

Permalink
[Feature] Add CMMLU dataset (open-compass#91)
Browse files Browse the repository at this point in the history
* add CMMLU

* debug cmmlu

* add slurm args `qos`

* fix format: space before comment

* remove unused variable

* change the location of `answer is`

---------

Co-authored-by: 李浩楠 <[email protected]>
Co-authored-by: 李浩楠 <haonan.li>
Co-authored-by: Leymore <[email protected]>
  • Loading branch information
3 people authored Jul 25, 2023
1 parent d3ef665 commit 3c366dd
Show file tree
Hide file tree
Showing 8 changed files with 298 additions and 0 deletions.
4 changes: 4 additions & 0 deletions configs/datasets/cmmlu/cmmlu_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
from .cmmlu_gen_ffe7c0 import cmmlu_datasets # noqa: F401, F403
122 changes: 122 additions & 0 deletions configs/datasets/cmmlu/cmmlu_gen_ffe7c0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import FixKRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import CMMLUDataset
from opencompass.utils.text_postprocessors import first_capital_postprocess

# CMMLU subject identifier -> Chinese subject name interpolated into prompts.
cmmlu_subject_mapping = {
    'agronomy': '农学',
    'anatomy': '解剖学',
    'ancient_chinese': '古汉语',
    'arts': '艺术学',
    'astronomy': '天文学',
    'business_ethics': '商业伦理',
    'chinese_civil_service_exam': '中国公务员考试',
    'chinese_driving_rule': '中国驾驶规则',
    'chinese_food_culture': '中国饮食文化',
    'chinese_foreign_policy': '中国外交政策',
    'chinese_history': '中国历史',
    'chinese_literature': '中国文学',
    'chinese_teacher_qualification': '中国教师资格',
    'clinical_knowledge': '临床知识',
    'college_actuarial_science': '大学精算学',
    'college_education': '大学教育学',
    'college_engineering_hydrology': '大学工程水文学',
    'college_law': '大学法律',
    'college_mathematics': '大学数学',
    'college_medical_statistics': '大学医学统计',
    'college_medicine': '大学医学',
    'computer_science': '计算机科学',
    'computer_security': '计算机安全',
    'conceptual_physics': '概念物理学',
    'construction_project_management': '建设工程管理',
    'economics': '经济学',
    'education': '教育学',
    'electrical_engineering': '电气工程',
    'elementary_chinese': '小学语文',
    'elementary_commonsense': '小学常识',
    'elementary_information_and_technology': '小学信息技术',
    'elementary_mathematics': '初等数学',
    'ethnology': '民族学',
    'food_science': '食品科学',
    'genetics': '遗传学',
    'global_facts': '全球事实',
    'high_school_biology': '高中生物',
    'high_school_chemistry': '高中化学',
    'high_school_geography': '高中地理',
    'high_school_mathematics': '高中数学',
    'high_school_physics': '高中物理学',
    'high_school_politics': '高中政治',
    'human_sexuality': '人类性行为',
    'international_law': '国际法学',
    'journalism': '新闻学',
    'jurisprudence': '法理学',
    'legal_and_moral_basis': '法律与道德基础',
    'logical': '逻辑学',
    'machine_learning': '机器学习',
    'management': '管理学',
    'marketing': '市场营销',
    'marxist_theory': '马克思主义理论',
    'modern_chinese': '现代汉语',
    'nutrition': '营养学',
    'philosophy': '哲学',
    'professional_accounting': '专业会计',
    'professional_law': '专业法学',
    'professional_medicine': '专业医学',
    'professional_psychology': '专业心理学',
    'public_relations': '公共关系',
    'security_study': '安全研究',
    'sociology': '社会学',
    'sports_science': '体育学',
    'traditional_chinese_medicine': '中医中药',
    'virology': '病毒学',
    'world_history': '世界历史',
    'world_religions': '世界宗教',
}


# Every subject in the mapping gets its own generation-style dataset config.
cmmlu_all_sets = list(cmmlu_subject_mapping.keys())

cmmlu_datasets = []
for _name in cmmlu_all_sets:
    _ch_name = cmmlu_subject_mapping[_name]

    # Only the Chinese subject name is interpolated here; the doubled braces
    # leave ``{question}``, ``{A}`` ... ``{D}`` as literal fields in the string.
    _question = f"以下是关于{_ch_name}的单项选择题,请直接给出正确答案的选项。\n题目:{{question}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}"
    _dialogue = [
        dict(role="HUMAN", prompt=_question),
        dict(role="BOT", prompt='答案是: {answer}'),
    ]

    cmmlu_infer_cfg = dict(
        ice_template=dict(
            type=PromptTemplate,
            template=dict(begin="</E>", round=_dialogue),
            ice_token="</E>",
        ),
        retriever=dict(type=FixKRetriever),
        # NOTE(review): presumably the indices of the ``dev`` rows used as
        # in-context examples; confirm against the inferencer.
        inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
    )

    cmmlu_eval_cfg = dict(
        evaluator=dict(type=AccEvaluator),
        pred_postprocessor=dict(type=first_capital_postprocess),
    )

    cmmlu_datasets.append(
        dict(
            type=CMMLUDataset,
            path="./data/cmmlu/",
            name=_name,
            abbr=f"cmmlu-{_name}",
            reader_cfg=dict(
                input_columns=["question", "A", "B", "C", "D"],
                output_column="answer",
                train_split="dev",
                test_split='test',
            ),
            infer_cfg=cmmlu_infer_cfg,
            eval_cfg=cmmlu_eval_cfg,
        ))

# Drop the loop temporaries so they do not end up as entries of the config.
del _name, _ch_name, _question, _dialogue
4 changes: 4 additions & 0 deletions configs/datasets/cmmlu/cmmlu_ppl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
from .cmmlu_ppl_fd1f2f import cmmlu_datasets # noqa: F401, F403
122 changes: 122 additions & 0 deletions configs/datasets/cmmlu/cmmlu_ppl_fd1f2f.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import FixKRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import CMMLUDataset
from opencompass.utils.text_postprocessors import first_capital_postprocess

# CMMLU subject identifier -> Chinese subject name interpolated into prompts.
cmmlu_subject_mapping = {
    'agronomy': '农学',
    'anatomy': '解剖学',
    'ancient_chinese': '古汉语',
    'arts': '艺术学',
    'astronomy': '天文学',
    'business_ethics': '商业伦理',
    'chinese_civil_service_exam': '中国公务员考试',
    'chinese_driving_rule': '中国驾驶规则',
    'chinese_food_culture': '中国饮食文化',
    'chinese_foreign_policy': '中国外交政策',
    'chinese_history': '中国历史',
    'chinese_literature': '中国文学',
    'chinese_teacher_qualification': '中国教师资格',
    'clinical_knowledge': '临床知识',
    'college_actuarial_science': '大学精算学',
    'college_education': '大学教育学',
    'college_engineering_hydrology': '大学工程水文学',
    'college_law': '大学法律',
    'college_mathematics': '大学数学',
    'college_medical_statistics': '大学医学统计',
    'college_medicine': '大学医学',
    'computer_science': '计算机科学',
    'computer_security': '计算机安全',
    'conceptual_physics': '概念物理学',
    'construction_project_management': '建设工程管理',
    'economics': '经济学',
    'education': '教育学',
    'electrical_engineering': '电气工程',
    'elementary_chinese': '小学语文',
    'elementary_commonsense': '小学常识',
    'elementary_information_and_technology': '小学信息技术',
    'elementary_mathematics': '初等数学',
    'ethnology': '民族学',
    'food_science': '食品科学',
    'genetics': '遗传学',
    'global_facts': '全球事实',
    'high_school_biology': '高中生物',
    'high_school_chemistry': '高中化学',
    'high_school_geography': '高中地理',
    'high_school_mathematics': '高中数学',
    'high_school_physics': '高中物理学',
    'high_school_politics': '高中政治',
    'human_sexuality': '人类性行为',
    'international_law': '国际法学',
    'journalism': '新闻学',
    'jurisprudence': '法理学',
    'legal_and_moral_basis': '法律与道德基础',
    'logical': '逻辑学',
    'machine_learning': '机器学习',
    'management': '管理学',
    'marketing': '市场营销',
    'marxist_theory': '马克思主义理论',
    'modern_chinese': '现代汉语',
    'nutrition': '营养学',
    'philosophy': '哲学',
    'professional_accounting': '专业会计',
    'professional_law': '专业法学',
    'professional_medicine': '专业医学',
    'professional_psychology': '专业心理学',
    'public_relations': '公共关系',
    'security_study': '安全研究',
    'sociology': '社会学',
    'sports_science': '体育学',
    'traditional_chinese_medicine': '中医中药',
    'virology': '病毒学',
    'world_history': '世界历史',
    'world_religions': '世界宗教',
}


# Every subject in the mapping gets its own perplexity-style dataset config.
cmmlu_all_sets = list(cmmlu_subject_mapping.keys())

cmmlu_datasets = []
for _name in cmmlu_all_sets:
    _ch_name = cmmlu_subject_mapping[_name]

    # The question text is the same for every candidate answer, so it is
    # rendered once. The doubled braces leave ``{question}``, ``{A}`` ...
    # ``{D}`` as literal fields in the resulting string.
    _question = f"以下是关于{_ch_name}的单项选择题,请直接给出正确答案的选项。\n题目:{{question}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}"

    cmmlu_infer_cfg = dict(
        ice_template=dict(
            type=PromptTemplate,
            # One template per option letter: the BOT turn has that letter
            # written into it directly instead of a placeholder.
            template={
                answer: dict(
                    begin="</E>",
                    round=[
                        dict(role="HUMAN", prompt=_question),
                        dict(role="BOT", prompt=f'答案是: {answer}'),
                    ],
                )
                for answer in ["A", "B", "C", "D"]
            },
            ice_token="</E>",
        ),
        retriever=dict(type=FixKRetriever),
        # NOTE(review): presumably the indices of the ``dev`` rows used as
        # in-context examples; confirm against the inferencer.
        inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
    )

    cmmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator))

    cmmlu_datasets.append(
        dict(
            type=CMMLUDataset,
            path="./data/cmmlu/",
            name=_name,
            abbr=f"cmmlu-{_name}",
            reader_cfg=dict(
                input_columns=["question", "A", "B", "C", "D"],
                output_column="answer",
                train_split="dev",
                test_split='test',
            ),
            infer_cfg=cmmlu_infer_cfg,
            eval_cfg=cmmlu_eval_cfg,
        ))

# Drop the loop temporaries so they do not end up as entries of the config.
del _name, _ch_name, _question
1 change: 1 addition & 0 deletions opencompass/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .chid import * # noqa: F401, F403
from .civilcomments import * # noqa: F401, F403
from .cluewsc import * # noqa: F401, F403
from .cmmlu import * # noqa: F401, F403
from .cmnli import * # noqa: F401, F403
from .cmrc import * # noqa: F401, F403
from .commonsenseqa import * # noqa: F401, F403
Expand Down
34 changes: 34 additions & 0 deletions opencompass/datasets/cmmlu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import csv
import os.path as osp

from datasets import Dataset, DatasetDict

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class CMMLUDataset(BaseDataset):
    """Loader that reads one CMMLU subject from per-split CSV files."""

    @staticmethod
    def load(path: str, name: str):
        """Load the ``dev`` and ``test`` splits of a single subject.

        For each split the file ``<path>/<split>/<name>.csv`` is read. Its
        first line is treated as a header and skipped. Every following row
        must have exactly seven columns: the first column is ignored and the
        remaining six are stored as ``question``, ``A``, ``B``, ``C``, ``D``
        and ``answer``.

        Args:
            path (str): Directory containing the ``dev`` and ``test``
                sub-directories.
            name (str): Subject name, i.e. the CSV file name without its
                extension.

        Returns:
            DatasetDict: Mapping with the keys ``'dev'`` and ``'test'``, each
            holding a ``Dataset`` built from the rows of that split.

        Raises:
            AssertionError: If a data row does not have exactly seven columns.
        """
        dataset = DatasetDict()
        for split in ['dev', 'test']:
            raw_data = []
            filename = osp.join(path, split, f'{name}.csv')
            # The csv module expects files opened with newline='' so that it
            # handles line endings itself; without it, line breaks embedded
            # in quoted fields are translated before the reader sees them.
            with open(filename, encoding='utf-8', newline='') as f:
                reader = csv.reader(f)
                next(reader)  # skip the header
                for row in reader:
                    assert len(row) == 7, (
                        f'{filename}, line {reader.line_num}: expected 7 '
                        f'columns, got {len(row)}')
                    raw_data.append({
                        'question': row[1],
                        'A': row[2],
                        'B': row[3],
                        'C': row[4],
                        'D': row[5],
                        'answer': row[6],
                    })
            dataset[split] = Dataset.from_list(raw_data)
        return dataset
5 changes: 5 additions & 0 deletions opencompass/runners/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class SlurmRunner(BaseRunner):
retry (int): Number of retries if the job failed. Defaults to 2.
partition (str): Slurm partition name. Defaults to None.
quotatype (str): Slurm quota type. Defaults to None.
qos (str): Slurm quality of service. Defaults to None.
debug (bool): Whether to run in debug mode. Defaults to False.
lark_bot_url (str): Lark bot url. Defaults to None.
"""
Expand All @@ -38,13 +39,15 @@ def __init__(self,
retry: int = 2,
partition: str = None,
quotatype: str = None,
qos: str = None,
debug: bool = False,
lark_bot_url: str = None):
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
self.max_num_workers = max_num_workers
self.retry = retry
self.partition = partition
self.quotatype = quotatype
self.qos = qos

def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
"""Launch multiple tasks.
Expand Down Expand Up @@ -97,6 +100,8 @@ def _launch(self, task_cfg: ConfigDict, random_sleep: bool = True):
tmpl += f' -p {self.partition}'
if self.quotatype:
tmpl += f' --quotatype={self.quotatype}'
if self.qos:
tmpl += f' --qos={self.qos}'
if num_gpus > 0:
tmpl += f' --gres=gpu:{num_gpus}'
tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
Expand Down
6 changes: 6 additions & 0 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ def parse_slurm_args(slurm_parser):
help='Slurm quota type',
default=None,
type=str)
slurm_parser.add_argument('--qos',
help='Slurm quality of service',
default=None,
type=str)


def parse_dlc_args(dlc_parser):
Expand Down Expand Up @@ -286,6 +290,7 @@ def exec_infer_runner(tasks, args, cfg):
max_num_workers=args.max_num_workers,
partition=args.partition,
quotatype=args.quotatype,
qos=args.qos,
retry=args.retry,
debug=args.debug,
lark_bot_url=cfg['lark_bot_url'])
Expand All @@ -311,6 +316,7 @@ def exec_eval_runner(tasks, args, cfg):
max_num_workers=args.max_num_workers,
partition=args.partition,
quotatype=args.quotatype,
qos=args.qos,
retry=args.retry,
debug=args.debug,
lark_bot_url=cfg['lark_bot_url'])
Expand Down

0 comments on commit 3c366dd

Please sign in to comment.