
Commit

…nto add_split_param
DesmonDay committed Oct 22, 2024
2 parents ae9ddce + 76a118b commit 4ab0df1
Showing 48 changed files with 3,007 additions and 5,424 deletions.
26 changes: 13 additions & 13 deletions README.md
@@ -115,19 +115,19 @@ The Unified Checkpoint large-model storage format supports dynamic …

* Large-model pretraining, fine-tuning (including SFT and PEFT techniques), alignment, and quantization already support the LLaMA, Baichuan, Bloom, ChatGLM, Mistral, OPT, and Qwen series. The [LLM] support matrix for pretraining, fine-tuning, alignment, and quantization is as follows:

| Model / Capability | Pretrain | SFT | LoRA | Prefix Tuning | DPO | RLHF | Quantization | Torch convert |
|:------------------:|:--------:|:---:|:----:|:-------------:|:---:|:----:|:------------:|:-------------:|
| Llama              | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Qwen               | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
| Mixtral            | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 |
| Mistral            | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
| Baichuan/Baichuan2 | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ |
| ChatGLM-6B         | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
| ChatGLM2/ChatGLM3  | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
| Bloom              | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
| GPT-3              | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
| OPT                | ✅ | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
| Yuan2              | ✅ | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
| Model / Capability | Pretrain | SFT | FlashMask | LoRA | Prefix Tuning | DPO | RLHF | Quantization | Torch convert |
|:------------------:|:--------:|:---:|:---------:|:----:|:-------------:|:---:|:----:|:------------:|:-------------:|
| Llama              | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Qwen               | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
| Mixtral            | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 |
| Mistral            | ✅ | ✅ | 🚧 | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
| Baichuan/Baichuan2 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ |
| ChatGLM-6B         | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
| ChatGLM2/ChatGLM3  | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
| Bloom              | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
| GPT-3              | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
| OPT                | ✅ | ✅ | 🚧 | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
| Yuan2              | ✅ | ✅ | 🚧 | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
------------------------------------------------------------------------------------------

* [Large-model inference](./llm/docs/predict/inference.md) already supports the LLaMA, Qwen, Mistral, ChatGLM, Bloom, and Baichuan series, with Weight-Only INT8 and INT4 inference as well as WAC (weights, activations, Cache KV) INT8/FP8 quantized inference. The [LLM] inference support matrix is as follows:
1 change: 1 addition & 0 deletions docs/index.rst
@@ -57,6 +57,7 @@
Large Model Unified Checkpoint Documentation <llm/docs/unified_checkpoint.md>
Hybrid Parallel Training Tutorial <llm/docs/llm_trainer.rst>
Model Weight Conversion Tutorial <llm/docs/torch2paddle.md>
Large Model DPO Documentation <llm/docs/dpo.md>

.. toctree::
:maxdepth: 1
1 change: 1 addition & 0 deletions docs/llm/docs/dpo.md
10 changes: 7 additions & 3 deletions llm/README.md
@@ -15,18 +15,21 @@

## 🛠️ Supported Models 🛠️

| Model | Pretrain | SFT | LoRA | Prefix Tuning | DPO | RLHF | Quantization | Torch convert |
| Model | Pretrain | SFT | LoRA | Prefix Tuning | DPO/SimPO/ORPO | RLHF | Quantization | Torch convert |
|----------------------------------------|----------|-----|------|---------------|-----|------|--------------|---------------|
| [LLaMA](./config/llama)                 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Qwen](./config/qwen)                   | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
| [Mixtral](./config/mixtral)             | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 |
| [Mixtral](./config/mixtral)             | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | 🚧 |
| [Mistral](./config/mistral)             | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
| [Baichuan/Baichuan2](./config/llama)    | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ |
| [ChatGLM-6B](./config/chatglm)          | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
| [ChatGLM2/ChatGLM3](./config/chatglm2)  | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
| [ChatGLM2/ChatGLM3](./config/chatglm2)  | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ |
| [Bloom](./config/bloom)                 | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
| [GPT-3](./config/gpt-3)                 | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
| [OPT](./config/opt)                     | 🚧 | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
| [Gemma](./config/gemma)                 | 🚧 | ✅ | 🚧 | 🚧 | ✅ | 🚧 | 🚧 | 🚧 |
| [Yuan](./config/yuan)                   | ✅ | ✅ | ✅ | 🚧 | ✅ | 🚧 | 🚧 | 🚧 |


- ✅: Supported
- 🚧: In Progress
@@ -193,6 +196,7 @@ tar -zxvf ultrafeedback_binarized.tar.gz
# Reference command for launching DPO
python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/dpo/run_dpo.py ./config/llama/dpo_argument.json
```
For more DPO technical details and usage instructions, see the [DPO documentation](./docs/dpo.md).
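
As a rough illustration of what such a JSON argument file contains, the sketch below assembles one from the `DPOConfig` fields that appear in this commit's `llm/alignment/dpo/dpo_argument.py`; the output file name is hypothetical, and a real `dpo_argument.json` additionally carries model, dataset, and trainer settings that are not reproduced here.

```python
import json

# Hedged sketch: keys mirror the DPOConfig fields visible in
# llm/alignment/dpo/dpo_argument.py in this commit; model, data, and
# trainer settings of a real dpo_argument.json are omitted.
dpo_args = {
    "beta": 0.1,              # temperature of the preference loss
    "loss_type": "sigmoid",   # standard DPO loss variant
    "label_smoothing": 0.0,
    "pref_loss_ratio": 1.0,
    "sft_loss_ratio": 0.0,
    "simpo_gamma": 0.5,       # only used by the SimPO variant
    "reference_free": False,  # True drops the reference model
    "lora": False,
}

with open("my_dpo_argument.json", "w") as f:  # hypothetical file name
    json.dump(dpo_args, f, indent=4)
```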

#### 3.2 RLHF

6 changes: 1 addition & 5 deletions llm/alignment/dpo/dpo_argument.py
@@ -91,15 +91,11 @@ class DPOConfig:

    beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
    simpo_gamma: float = field(default=0.5, metadata={"help": "the gamma parameter for SimPO loss"})
    normalize_logps: bool = field(
        default=True,
        metadata={"help": "Apply logprobs normalization."},
    )
    label_smoothing: float = field(default=0.0, metadata={"help": "label_smoothing ratio"})
    loss_type: str = field(default="sigmoid", metadata={"help": "DPO loss type"})
    pref_loss_ratio: float = field(default=1.0, metadata={"help": "DPO loss ratio"})
    sft_loss_ratio: float = field(default=0.0, metadata={"help": "SFT loss ratio"})
    dpop_lambda: float = field(default=50, metadata={"help": "SFT loss ratio"})
    dpop_lambda: float = field(default=50, metadata={"help": "dpop_lambda"})
    ref_model_update_steps: int = field(default=-1, metadata={"help": "Update ref model state dict "})
    reference_free: bool = field(default=False, metadata={"help": "No reference model."})
    lora: bool = field(default=False, metadata={"help": "Use LoRA model."})
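
For readers unfamiliar with these knobs, the standalone sketch below shows the standard sigmoid-style DPO preference loss and how `beta`, `label_smoothing`, and `reference_free` enter it. This is a generic illustration of the DPO formulation, not PaddleNLP's implementation, and the function name is made up for this example.

```python
import math


def log_sigmoid(x: float) -> float:
    # log(sigmoid(x)) = -log(1 + exp(-x)); adequate for this illustration
    return -math.log1p(math.exp(-x))


def sigmoid_dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
    label_smoothing: float = 0.0,
    reference_free: bool = False,
) -> float:
    """Generic sigmoid DPO loss for one preference pair (illustrative only)."""
    pi_logratio = policy_chosen_logp - policy_rejected_logp
    ref_logratio = 0.0 if reference_free else ref_chosen_logp - ref_rejected_logp
    logits = beta * (pi_logratio - ref_logratio)
    # label_smoothing blends the loss with its label-flipped counterpart
    return -(1.0 - label_smoothing) * log_sigmoid(logits) - label_smoothing * log_sigmoid(-logits)


# The policy separates chosen from rejected more strongly than the reference -> loss below log(2)
print(sigmoid_dpo_loss(-12.0, -20.0, -14.0, -18.0))
```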
40 changes: 40 additions & 0 deletions llm/config/deepseek-v2/pretrain_argument.json
@@ -0,0 +1,40 @@
{
"model_name_or_path": "deepseek-ai/DeepSeek-V2-Lite",
"tokenizer_name_or_path": "deepseek-ai/DeepSeek-V2-Lite",
"input_dir": "./data",
"output_dir": "./checkpoints/pretrain_ckpts",
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 1,
"per_device_eval_batch_size": 1,
"tensor_parallel_degree": 1,
"pipeline_parallel_degree": 1,
"sharding_parallel_degree": 1,
"sharding": "stage2",
"virtual_pp_degree": 1,
"sequence_parallel": 0,
"use_flash_attention": true,
"max_seq_length": 4096,
"learning_rate": 3e-05,
"min_learning_rate": 3e-06,
"warmup_steps": 30,
"logging_steps": 1,
"max_steps": 10000,
"save_steps": 5000,
"eval_steps": 1000,
"weight_decay": 0.01,
"bf16": true,
"fp16_opt_level": "O2",
"warmup_ratio": 0.01,
"max_grad_norm": 1.0,
"dataloader_num_workers": 1,
"continue_training": 1,
"do_train": true,
"do_eval": true,
"do_predict": true,
"disable_tqdm": true,
"recompute": true,
"distributed_dataloader": 1,
"recompute_granularity": "full",
"unified_checkpoint": true,
"save_total_limit": 2
}
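
As a quick sanity check on how the parallelism fields in this config combine, here is a hedged sketch of the usual hybrid-parallel bookkeeping: tensor- and pipeline-parallel ranks jointly hold one model replica, the remaining ranks act as data-parallel replicas, and sharding `stage2` only partitions optimizer state and gradients inside those replicas. The GPU count is an assumed value for illustration.

```python
import json

# Hedged sketch: standard hybrid-parallel arithmetic applied to the config
# above; the path and the GPU count are assumptions.
with open("llm/config/deepseek-v2/pretrain_argument.json") as f:
    cfg = json.load(f)

num_gpus = 8  # assumed cluster size

# Ranks holding one model replica = tensor_parallel * pipeline_parallel;
# everything else is data parallelism, which sharding stage2 subdivides
# without changing the batch arithmetic.
ranks_per_replica = cfg["tensor_parallel_degree"] * cfg["pipeline_parallel_degree"]
dp_replicas = num_gpus // ranks_per_replica

global_batch = (
    cfg["per_device_train_batch_size"]
    * cfg["gradient_accumulation_steps"]
    * dp_replicas
)
print(f"data-parallel replicas: {dp_replicas}, global batch size: {global_batch}")
```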
33 changes: 33 additions & 0 deletions llm/config/deepseek-v2/sft_argument.json
@@ -0,0 +1,33 @@
{
"model_name_or_path": "deepseek-ai/DeepSeek-V2-Lite",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/sft_ckpts",
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-05,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 1,
"pipeline_parallel_degree": 1,
"sharding": "stage2",
"zero_padding": false,
"unified_checkpoint": true,
"use_flash_attention": true
}
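
One relationship worth noting in this SFT config is between `src_length` and `max_length`: assuming the usual convention that `src_length` caps the prompt and `max_length` caps prompt plus response, the response budget is simply their difference. A tiny sketch under that assumption:

```python
import json

# Hedged sketch: assumes src_length bounds the prompt and max_length bounds
# prompt + response; the path is taken from the config listed above.
with open("llm/config/deepseek-v2/sft_argument.json") as f:
    cfg = json.load(f)

prompt_budget = cfg["src_length"]                        # 1024 tokens
response_budget = cfg["max_length"] - cfg["src_length"]  # 2048 - 1024 = 1024 tokens
print(f"prompt <= {prompt_budget} tokens, response <= {response_budget} tokens")
```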
