自定义数据集视频介绍:
下面是文字版:
我们支持三种自定义数据集的方法.
- 【推荐】直接命令行传参的方式,指定
--dataset xxx.json yyy.jsonl zzz.csv
, 更加方便支持自定义数据集, 支持五种数据集格式(即使用SmartPreprocessor
,支持的数据集格式见下方), 支持dataset_id
和dataset_path
. 不需要修改dataset_info.json
文件. 该方法适合刚接触ms-swift的用户, 下两种方法适合对ms-swift进行拓展的开发者. - 添加数据集到
dataset_info.json
中, 比第一种方式更灵活但繁琐, 支持对数据集使用两种预处理器并指定其参数:RenameColumnsPreprocessor
,ConversationsPreprocessor
(默认使用SmartPreprocessor
). 支持直接修改swift内置的dataset_info.json
, 或者通过--custom_dataset_info xxx.json
的方式传入外置的json文件(方便pip install而非git clone的用户拓展数据集). - 注册数据集的方式: 比第1、2种方式更加灵活但繁琐, 支持使用函数对数据集进行预处理. 方法1、2在实现上借助了方法3. 可以直接修改源码进行拓展, 或者通过
--custom_register_path xxx.py
的方式传入, 脚本会对py文件进行解析(方便pip install的用户).
支持直接传入行自定义的dataset_id(兼容MS和HF)和dataset_path, 以及同时传入多个自定义数据集以及对应采样数, 脚本会进行自动的预处理和拼接. 如果传入的是dataset_id
, 默认会使用dataset_id中的'default'子数据集, 并设置split为'train'. 如果该dataset_id已经注册, 则会使用注册时传入的subsets、split以及预处理函数. 如果传入的是dataset_path
, 则可以指定为相对路径和绝对路径, 其中相对路径为相对于当前运行目录.
每个数据集指定格式如下: [HF or MS::]{dataset_name} or {dataset_id} or {dataset_path}[:subset1/subset2/...][#dataset_sample]
, 最简只需要指定dataset_name、dataset_id或者dataset_path即可.
# 默认使用modelscope的dataset_id, 同时也支持huggingface的dataset_id
--dataset {dataset_id} {dataset_path} HF::{dataset_id}
# 数据集混合: 以下取dataset_id中subset1和subset2子数据集并随机采样20000条. 如果不使用`#{dataset_sample}`, 则使用数据集中的所有样本
--dataset {dataset_name}#20000 {dataset_id}:{subset1}/{subset2}#20000 {dataset_path}#10000
脚本支持的文件格式包含csv
, json
, jsonl
格式. 你需要将传入的文件符合以下数据集格式(只列出了一部分). 以下格式都支持system (需要注意的是, csv如果指定了system字段, 则无法设置为None
, 只能指定为空字符串. json和jsonl没有这个限制). json
, jsonl
格式的文件支持多轮对话 (csv
不支持).
格式1:
预训练:
response
11111
aaaaa
AAAAA
{"response": "11111"}
{"response": "aaaaa"}
{"response": "AAAAA"}
单轮对话:
system,query,response
00000,11111,22222
00001,aaaaa,bbbbb
00002,AAAAA,BBBBB
{"system": "00000", "query": "11111", "response": "22222"}
{"query": "aaaaa", "response": "bbbbb"}
{"system": "00001", "query": "AAAAA", "response": "BBBBB"}
多轮对话:
{"system": "00000", "query": "55555", "response": "66666"}
{"query": "eeeee", "response": "fffff", "history": []}
{"system": "00001", "query": "EEEEE", "response": "FFFFF", "history": [["query1", "response1"], ["query2", "response2"]]}
[{"system": "00000", "query": "55555", "response": "66666"},
{"query": "eeeee", "response": "fffff", "history": []},
{"system": "00001", "query": "EEEEE", "response": "FFFFF", "history": [["query1", "response1"], ["query2", "response2"]]}]
格式2:
{"conversations": [{"from": "system", "value": "00000"}, {"from": "user", "value": "11111"}, {"from": "assistant", "value": "22222"}]}
{"conversations": [{"from": "user", "value": "aaaaa"}, {"from": "assistant", "value": "bbbbb"}, {"from": "user", "value": "ccccc"}, {"from": "assistant", "value": "ddddd"}]}
{"conversations": [{"from": "user", "value": "AAAAA"}, {"from": "assistant", "value": "BBBBB"}, {"from": "user", "value": "CCCCC"}, {"from": "assistant", "value": "DDDDD"}]}
格式3:
{"messages": [{"role": "system", "content": "00000"}, {"role": "user", "content": "11111"}, {"role": "assistant", "content": "22222"}]}
{"messages": [{"role": "user", "content": "aaaaa"}, {"role": "assistant", "content": "bbbbb"}, {"role": "user", "content": "ccccc"}, {"role": "assistant", "content": "ddddd"}]}
{"messages": [{"role": "user", "content": "AAAAA"}, {"role": "assistant", "content": "BBBBB"}, {"role": "user", "content": "CCCCC"}, {"role": "assistant", "content": "DDDDD"}]}
格式4:
{"system": "00000", "conversation": [{"human": "11111", "assistant": "22222"}]}
{"conversation": [{"human": "aaaaa", "assistant": "bbbbb"}]}
{"system": "00001", "conversation": [{"human": "AAAAA", "assistant": "BBBBB"}, {"human": "CCCCC", "assistant": "DDDDD"}, {"human": "EEEEE", "assistant": "FFFFF"}]}
格式5:
system,instruction,input,output
00000,11111,22222,33333
00001,aaaaa,bbbbb,ccccc
00002,AAAAA,BBBBB,CCCCC
额外的预训练格式:
{"text": "11111"}
{"text": "aaaaa"}
{"text": "AAAAA"}
人类对齐
语言模型(DPO/ORPO/SimPO/CPO)
{"system": "123", "query": "11111", "response": "22222", "rejected_response": "33333", "history": [["query1", "response1"], ["query2", "response2"]]}
{"system": "123", "query": "aaaaa", "response": "bbbbb", "rejected_response": "ccccc", "history": [["query1", "response1"], ["query2", "response2"]]}
{"system": "123", "query": "AAAAA", "response": "BBBBB", "rejected_response": "CCCCC", "history": [["query1", "response1"], ["query2", "response2"]]}
- 其中
system
和history
为可选项
语言模型 (KTO)
{"query": "11111", "response": "22222", "label": true}
{"query": "aaaaa", "response": "bbbbb", "label": false}
{"system": "123", "query": "AAAAA", "response": "BBBBB", "label": true, "history": [["query1", "response1"], ["query2", "response2"]]}
-
注意
label
需要是bool类型, 不能是字符串 -
其中
system
和history
为可选项
视觉多模态大模型(DPO/ORPO/SimPO/CPO)
{"system": "123", "query": "11111", "response": "22222", "rejected_response": "33333", "images": ["image_path"], "history": [["query1", "response1"], ["query2", "response2"]]}
{"system": "123", "query": "aaaaa", "response": "bbbbb", "rejected_response": "ccccc", "images": ["image_path"], "history": [["query1", "response1"], ["query2", "response2"]]}
{"system": "123", "query": "AAAAA", "response": "BBBBB", "rejected_response": "CCCCC", "images": ["image_path"], "history": [["query1", "response1"], ["query2", "response2"]]}
-
不同模型对图像数量的支持不同, 具体参考模型对应的最佳实践文档
-
其中
system
和history
为可选项
Tool-Calling Agent
格式1
{"tools":"{API_LIST}","conversations": [{"from": "system", "value": "00000"}, {"from": "user", "value": "11111"}, {"from": "assistant", "value": "22222"}]}
{"tools":"{API_LIST}","conversations": [{"from": "user", "value": "aaaaa"}, {"from": "assistant", "value": "bbbbb"}, {"from": "tool", "value": "ccccc"}, {"from": "assistant", "value": "ddddd"}]}
{"tools":"{API_LIST}","conversations": [{"from": "user", "value": "AAAAA"}, {"from": "assistant", "value": "BBBBB"}, {"from": "tool", "value": "CCCCC"}, {"from": "assistant", "value": "DDDDD"}]}
格式2
{"tools":"{API_LIST}","messages": [{"role": "system", "content": "00000"}, {"role": "user", "content": "11111"}, {"role": "assistant", "content": "22222"}]}
{"tools":"{API_LIST}","messages": [{"role": "user", "content": "aaaaa"}, {"role": "assistant", "content": "bbbbb"}, {"role": "tool", "content": "ccccc"}, {"role": "assistant", "content": "ddddd"}]}
{"tools":"{API_LIST}","messages": [{"role": "user", "content": "AAAAA"}, {"role": "assistant", "content": "BBBBB"}, {"role": "tool", "content": "CCCCC"}, {"role": "assistant", "content": "DDDDD"}]}
其中tools格式参考Agent部署文档, 你可以通过设置--tools_prompt
来选择对应的prompt
tool
字段表示工具调用返回结果
可以参考swift内置的dataset_info.json进行数据集拓展. 你可以直接在内置的dataset_info.json中添加, 也可以通过--custom_dataset_info 1.json
传入外置的dataset_info.json的路径、json字符串或者字典.
添加dataset_id:
# MS
# 使用: `--dataset <dataset_name>`
"<dataset_name>": {
"dataset_id": "xxx/xxx"
}
# HF
# 使用: `--dataset HF::<dataset_name>` 或者 直接使用`USE_HF`环境变量.
"<dataset_name>": {
"hf_dataset_id": "xxx/xxx"
}
添加dataset_path:
# 可以指定相对路径和绝对路径. 相对路径相对于dataset_info.json文件所在目录.
# 使用: `--dataset <dataset_name>`
"<dataset_name>": {
"dataset_path": "xxx"
}
支持以下参数:
- dataset_id: 数据集对应的ModelScope的dataset_id, 默认为
None
. 最简的设置必须指定dataset_id
、hf_dataset_id
和dataset_path
中的一个. - subsets: 子数据集的名字列表, 默认为
[]
, 即使用'default'子数据集. - split: 默认为
['train']
, 通常不需要修改. - hf_dataset_id: 数据集对应的HuggingFace的datasset_id, 默认为
None
. - dataset_path: 用于指定数据集的本地路径, e.g. 1.jsonl等, 默认为
None
. 可以传入相对路径和绝对路径. 如果使用相对路径, 则相对于dataset_info.json
文件所在目录. 如果设置了dataset_path, 那么dataset_id, subsets, hf_dataset_id参数失效. - columns: 默认使用的预处理器为
SmartPreprocessor
, 指定此参数则指定为RenameColumnsPreprocessor
, 你需要rename数据集中的列并转换为上述格式1的样式. - conversations: 指定此参数则指定预处理器为
ConversationsPreprocessor
('columns'的优先级高于'conversations'). - remove_useless_columns: 指定是否移除无用的列 (包括: 'query', 'response', 'rejected_response', 'system', 'history', 'images'), 默认为
True
, 通常不需要设置. - tags: 用于注释数据集, 默认为
[]
, 通常不需要设置.
如果dataset_info.json
中参数无法满足您的要求, 例如你需要添加自定义的prompt、需要对数据集提前进行清洗或者进行复杂的数据集获取与预处理, 则可以使用注册数据集的方式, 使用函数的方式来进行数据获取与预处理.
以下是一个注册数据集的案例. 完整的py文件可以查看custom.py, sh脚本可以查看custom. 你可以通过指定--custom_register_path xxx.py
对注册的内容进行解析.
from typing import Optional, Tuple
from datasets import Dataset as HfDataset
from modelscope import MsDataset
from swift.llm import get_dataset, register_dataset, get_dataset_from_repo
from swift.utils import get_logger
logger = get_logger()
class CustomDatasetName:
stsb_en = 'stsb-en'
def _preprocess_stsb(dataset: HfDataset) -> HfDataset:
prompt = """Task: Based on the given two sentences, provide a similarity score between 0.0 and 5.0.
Sentence 1: {text1}
Sentence 2: {text2}
Similarity score: """
query = []
response = []
for d in dataset:
query.append(prompt.format(text1=d['text1'], text2=d['text2']))
response.append(f"{d['label']:.1f}")
return HfDataset.from_dict({'query': query, 'response': response})
register_dataset(CustomDatasetName.stsb_en, 'swift/stsb', None, _preprocess_stsb, get_dataset_from_repo)
if __name__ == '__main__':
# test dataset
train_dataset, val_dataset = get_dataset([CustomDatasetName.stsb_en],
check_dataset_strategy='warning')
print(f'train_dataset: {train_dataset}')
print(f'val_dataset: {val_dataset}')
register_dataset
会在DATASET_MAPPING
中注册数据集, 该函数的参数含义如下:
-
dataset_name
: 必填项, 表示数据集的名字, 也是数据集的唯一id. -
dataset_id_or_path
: 必填项. 表示数据集在ModelScope Hub上的dataset_id
或者本地的dataset_dir
. -
subsets
: 数据集的子数据集列表, 默认为[]
. -
preprocess_func
: 预处理函数. -
get_function
: 默认值为None
. 获取数据集的函数. 如果传入None, 则使用修饰器方案进行数据集注册. 如果传入一个函数, 则使用正常方案进行注册.get_function
需要返回HfDataset
或Tuple[HfDataset, Optional[HfDataset]]
. 如果只返回一个数据集, 则该数据集为train_dataset. 如果返回两个数据集, 则分别作为train_dataset和val_dataset.get_dataset
函数支持获取多个数据集, 例如:get_dataset(['dataset1', 'dataset2'])
. 我们会将各个子数据集的训练集和验证集部分分别进行拼接, 最终返回合并后的训练集和验证集.函数返回的
HfDataset
需要符合一定的规范. 如果你要进行预训练, 那么只需要包含response
字段, 具体可以参考'tigerbot-law-zh'
数据集. 如果是指令微调(单轮对话)的情况下, 需包含query
,response
字段, 分别代表指令微调的用户询问和AI助手的回答, 具体可以参考'alpaca-zh'
数据集. 如果是多轮对话, 则需要额外加上history
字段, 代表对话的历史信息, 具体可以参考'damo-agent-mini-zh'
数据集. 如果每个数据集样例具有不同的system
, 则需要额外加上system字段, 具体你也可以参考'damo-agent-mini-zh'
数据集. -
**kwargs
: 其他用于注释数据集的参数. 该参数一般不需要设置.
以下是一个自定义模型的案例. 完整的py文件可以查看custom.py, sh脚本可以查看custom. 你可以通过指定--custom_register_path xxx.py
对注册的内容进行解析.
from typing import Any, Dict
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.utils.versions import require_version
from swift.llm import LoRATM, TemplateType, get_model_tokenizer, register_model
from swift.utils import get_logger
logger = get_logger()
class CustomModelType:
tigerbot_7b = 'tigerbot-7b'
tigerbot_13b = 'tigerbot-13b'
tigerbot_13b_chat = 'tigerbot-13b-chat'
class CustomTemplateType:
tigerbot = 'tigerbot'
@register_model(CustomModelType.tigerbot_7b,
'TigerResearch/tigerbot-7b-base-v3', LoRATM.llama,
TemplateType.default_generation)
@register_model(CustomModelType.tigerbot_13b,
'TigerResearch/tigerbot-13b-base-v2', LoRATM.llama,
TemplateType.default_generation)
@register_model(CustomModelType.tigerbot_13b_chat,
'TigerResearch/tigerbot-13b-chat-v4', LoRATM.llama,
CustomTemplateType.tigerbot)
def get_tigerbot_model_tokenizer(model_dir: str,
torch_dtype: torch.dtype,
model_kwargs: Dict[str, Any],
load_model: bool = True,
**kwargs):
use_flash_attn = kwargs.pop('use_flash_attn', False)
if use_flash_attn:
require_version('transformers>=4.34')
logger.info('Setting use_flash_attention_2: True')
model_kwargs['use_flash_attention_2'] = True
model_config = AutoConfig.from_pretrained(
model_dir, trust_remote_code=True)
model_config.pretraining_tp = 1
model_config.torch_dtype = torch_dtype
logger.info(f'model_config: {model_config}')
tokenizer = AutoTokenizer.from_pretrained(
model_dir, trust_remote_code=True)
model = None
if load_model:
model = AutoModelForCausalLM.from_pretrained(
model_dir,
config=model_config,
torch_dtype=torch_dtype,
trust_remote_code=True,
**model_kwargs)
return model, tokenizer
if __name__ == '__main__':
# test model base
model, tokenizer = get_model_tokenizer(
CustomModelType.tigerbot_7b, use_flash_attn=False)
print(model.__class__.__name__)
# test model chat
model, tokenizer = get_model_tokenizer(
CustomModelType.tigerbot_13b_chat, use_flash_attn=False)
print(model.__class__.__name__)
register_model
会在MODEL_MAPPING
中注册模型, 该函数的参数含义如下:
model_type
: 必填项. 表示模型的名字, 也是唯一的id.model_id_or_path
: 必填项. 表示模型在ModelScope Hub中的model_id
, 或者是本地的模型目录model_dir
.lora_target_modules
: 默认为None
. 表示在sh脚本中指定--lora_target_modules DEFAULT
或--lora_target_modules AUTO
或未指定--lora_target_modules
情况下默认使用的lora_target_modules.template
: 默认为TemplateType.default
. 表示在sh脚本中指定--template_type AUTO
或未指定--template_type
情况下默认使用的对话模板.get_function
: 默认值为None
. 获取model和tokenizer的函数. 如果传入None, 则使用修饰器方案进行模型注册. 如果传入一个函数, 则使用正常方案进行注册.requires
: 默认为[]
. 表示模型所需要的区别于其他模型的依赖. 该参数一般不需要设置.torch_dtype
: 默认为None
. 表示模型所推荐使用的torch_dtype. 该参数一般不需要设置.revision
: 默认为None
. 用于指定模型的版本号. 如果model_id_or_path
是本地的模型目录, 则该参数失效. 该参数一般不需要设置.ignore_file_pattern
: 默认为None
. 表示下载的时候需要忽略的文件名的正则pattern, 该参数会传递给snapshot_download
. 例如r'.+\.bin$'
,r'.+\.savetensors$'
等. 该参数一般不需要设置.**kwargs
: 其他用于注释模型能力的参数. 该参数一般不需要设置.
以下是一个自定义模型的案例. 完整的py文件可以查看custom.py, sh脚本可以查看custom.
from swift.llm import (Template, ModelType, dataset_map,
get_model_tokenizer, get_template, get_dataset,
print_example, register_template, DatasetName)
from swift.utils import get_logger
logger = get_logger()
class CustomTemplateType:
tigerbot = 'tigerbot'
# Ref: https://github.com/TigerResearch/TigerBot/blob/main/infer.py
register_template(
CustomTemplateType.tigerbot,
Template(['{{SYSTEM}}'], ['\n\n### Instruction:\n{{QUERY}}\n\n### Response:\n'], [],
[['eos_token_id']]))
if __name__ == '__main__':
# test template
train_dataset, _ = get_dataset(DatasetName.blossom_math_zh)
_, tokenizer = get_model_tokenizer(ModelType.qwen_7b_chat, load_model=False)
template = get_template(CustomTemplateType.tigerbot, tokenizer)
train_dataset = dataset_map(train_dataset, template.encode)
print_example(train_dataset[0], tokenizer)
register_template
会在TEMPLATE_MAPPING
中注册对话模板, 该函数的参数含义如下:
template_type
: 必填项, 表示对话模板的名字, 也是template的唯一id.template
: 必填项, 需要传入一个Template
. 初始化Template
需要传入以下参数:prefix
,prompt
,chat_sep
,suffix
,default_system
.
模板初始化函数会根据这四个内容, 获取完整的chat template. 其中这四个配置内容的含义如下.
prefix
: 表示对话模板中的前缀部分, 一般为system部分, 前缀token, bos token等内容. 我们使用{{SYSTEM}}
作为system的占位符. 如果{{SYSTEM}}
没有在prefix中存在, 则该Template不支持system, e.g.damo-agent-mini-zh
数据集.prompt
: 表示对话模板中的一轮对话. 我们使用{{QUERY}}
作为每轮对话中, human询问部分的占位符,{{ROUND0}}
则表示本次对话是第几轮的占位符, 从0开始计数,{{ROUND1}}
从1开始计数. AI助手的回复部分会拼接在prompt
的后面, 因此我们没有设计其占位符. 我们只会对AI助手的回复部分计算损失.chat_sep
: 如果需要进行多轮对话,chat_sep
会作为每轮对话之间的分隔符, 例如: 换行等. 如果设置为None, 则该Template不支持多轮对话.suffix
: 作为对话模板的后缀部分, 一般为eos token. 会拼接在最后一轮的对话后面.default_system
: 默认的system.