-
* **2024.08.08 📚 "PaddleNLP 3.0, PaddlePaddle's Powerful Toolkit for Industrial-Grade LLM Development, Officially Released"**: the full training-compression-inference pipeline is connected, with coverage of mainstream models. Automatic parallelism for large models makes the full training and inference workflow for 100B-parameter models work out of the box. It provides industrial-grade, high-performance fine-tuning and alignment solutions, leading compression and inference, and adaptation to multiple hardware platforms, covering industrial scenarios such as intelligent assistants, content creation, knowledge-based QA, and key information extraction. Livestream: Thursday, August 22, 19:00. Registration link: https://www.wjx.top/vm/Y2f7FFY.aspx?udsid=143844
* **2024.06.27 [PaddleNLP v3.0 Beta](https://github.com/PaddlePaddle/PaddleNLP/releases/tag/v3.0.0-beta0)**: embrace LLMs with a fully upgraded experience. The unified LLM toolkit enables end-to-end integration of domestic compute chips; it fully supports the industrial-grade LLM workflow, including PaddlePaddle 4D parallel configuration, efficient fine-tuning strategies, efficient alignment algorithms, and high-performance inference; the self-developed RsLoRA+ algorithm with superior convergence, the Unified Checkpoint mechanism for automatically scalable checkpoint storage, and generalized FastFFN and FusedQKV support accelerate LLM training and inference; mainstream models are continuously supported and updated, with efficient solutions provided.
@@ -69,39 +74,45 @@
The high-performance inference module of the LLM toolkit has built-in dynamic insertion and end-to-end operator fusion strategies, which greatly accelerate parallel inference. The underlying implementation details are encapsulated, providing out-of-the-box high-performance parallel inference.
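As a concrete starting point, the high-performance path is typically driven through the LLM toolkit's predictor script. The sketch below is illustrative only: the script path and flag names are assumptions and should be checked against [the inference docs](./llm/docs/predict/inference.md).

```shell
# Illustrative sketch only -- script path and flag names are assumptions;
# see llm/docs/predict/inference.md for the authoritative usage.
cd llm
python predict/predictor.py \
    --model_name_or_path meta-llama/Llama-2-7b-chat \
    --inference_model \
    --dtype float16
```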
+## Documentation
+For more detailed documentation, please visit the [PaddleNLP Documentation](https://paddlenlp.readthedocs.io/).
+
------------------------------------------------------------------------------------------
## Model Support
* Model weights are available for the LLaMA, Baichuan, Bloom, ChatGLM, Gemma, Mistral, OPT, and Qwen series; the detailed 👉【LLM】supported model weights list is as follows:
-| 模型系列 | Model Name |
-|:-------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| [LLaMA](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | facebook/llama-7b, facebook/llama-13b, facebook/llama-30b, facebook/llama-65b |
-| [Llama2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-2-7b, meta-llama/Llama-2-7b-chat, meta-llama/Llama-2-13b, meta-llama/Llama-2-13b-chat, meta-llama/Llama-2-70b, meta-llama/Llama-2-70b-chat |
-| [Llama3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Meta-Llama-3-8B, meta-llama/Meta-Llama-3-8B-Instruct, meta-llama/Meta-Llama-3-70B, meta-llama/Meta-Llama-3-70B-Instruct |
-| [Llama3.1](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Meta-Llama-3.1-8B, meta-llama/Meta-Llama-3.1-8B-Instruct, meta-llama/Meta-Llama-3.1-70B, meta-llama/Meta-Llama-3.1-70B-Instruct, meta-llama/Meta-Llama-3.1-405B, meta-llama/Meta-Llama-3.1-405B-Instruct, meta-llama/Llama-Guard-3-8B |
-| [Llama3.2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-3.2-1B, meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.2-3B, meta-llama/Llama-3.2-3B-Instruct, meta-llama/Llama-Guard-3-1B |
-| [Llama3.3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-3.3-70B-Instruct |
-| [Baichuan](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/baichuan) | baichuan-inc/Baichuan-7B, baichuan-inc/Baichuan-13B-Base, baichuan-inc/Baichuan-13B-Chat |
-| [Baichuan2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/baichuan) | baichuan-inc/Baichuan2-7B-Base, baichuan-inc/Baichuan2-7B-Chat, baichuan-inc/Baichuan2-13B-Base, baichuan-inc/Baichuan2-13B-Chat |
-| [Bloom](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/bloom) | bigscience/bloom-560m, bigscience/bloom-560m-bf16, bigscience/bloom-1b1, bigscience/bloom-3b, bigscience/bloom-7b1, bigscience/bloomz-560m, bigscience/bloomz-1b1, bigscience/bloomz-3b, bigscience/bloomz-7b1-mt, bigscience/bloomz-7b1-p3, bigscience/bloomz-7b1, bellegroup/belle-7b-2m |
-| [ChatGLM](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm/) | THUDM/chatglm-6b, THUDM/chatglm-6b-v1.1 |
-| [ChatGLM2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm2-6b |
-| [ChatGLM3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm3-6b |
-| [DeepSeekV2](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V2, deepseek-ai/DeepSeek-V2-Chat, deepseek-ai/DeepSeek-V2-Lite, deepseek-ai/DeepSeek-V2-Lite-Chat, deepseek-ai/DeepSeek-Coder-V2-Base, deepseek-ai/DeepSeek-Coder-V2-Instruct, deepseek-ai/DeepSeek-Coder-V2-Lite-Base, deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct |
-| [Gemma](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/gemma) | google/gemma-7b, google/gemma-7b-it, google/gemma-2b, google/gemma-2b-it |
-| [Mistral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mistral) | mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-7B-v0.1 |
-| [Mixtral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mixtral) | mistralai/Mixtral-8x7B-Instruct-v0.1 |
-| [OPT](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/opt) | facebook/opt-125m, facebook/opt-350m, facebook/opt-1.3b, facebook/opt-2.7b, facebook/opt-6.7b, facebook/opt-13b, facebook/opt-30b, facebook/opt-66b, facebook/opt-iml-1.3b, opt-iml-max-1.3b |
-| [Qwen](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | qwen/qwen-7b, qwen/qwen-7b-chat, qwen/qwen-14b, qwen/qwen-14b-chat, qwen/qwen-72b, qwen/qwen-72b-chat, |
-| [Qwen1.5](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen1.5-0.5B, Qwen/Qwen1.5-0.5B-Chat, Qwen/Qwen1.5-1.8B, Qwen/Qwen1.5-1.8B-Chat, Qwen/Qwen1.5-4B, Qwen/Qwen1.5-4B-Chat, Qwen/Qwen1.5-7B, Qwen/Qwen1.5-7B-Chat, Qwen/Qwen1.5-14B, Qwen/Qwen1.5-14B-Chat, Qwen/Qwen1.5-32B, Qwen/Qwen1.5-32B-Chat, Qwen/Qwen1.5-72B, Qwen/Qwen1.5-72B-Chat, Qwen/Qwen1.5-110B, Qwen/Qwen1.5-110B-Chat, Qwen/Qwen1.5-MoE-A2.7B, Qwen/Qwen1.5-MoE-A2.7B-Chat |
-| [Qwen2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2-0.5B, Qwen/Qwen2-0.5B-Instruct, Qwen/Qwen2-1.5B, Qwen/Qwen2-1.5B-Instruct, Qwen/Qwen2-7B, Qwen/Qwen2-7B-Instruct, Qwen/Qwen2-72B, Qwen/Qwen2-72B-Instruct, Qwen/Qwen2-57B-A14B, Qwen/Qwen2-57B-A14B-Instruct |
-| [Qwen2-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2-Math-1.5B, Qwen/Qwen2-Math-1.5B-Instruct, Qwen/Qwen2-Math-7B, Qwen/Qwen2-Math-7B-Instruct, Qwen/Qwen2-Math-72B, Qwen/Qwen2-Math-72B-Instruct, Qwen/Qwen2-Math-RM-72B |
-| [Qwen2.5](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-0.5B, Qwen/Qwen2.5-0.5B-Instruct, Qwen/Qwen2.5-1.5B, Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-3B, Qwen/Qwen2.5-3B-Instruct, Qwen/Qwen2.5-7B, Qwen/Qwen2.5-7B-Instruct, Qwen/Qwen2.5-14B, Qwen/Qwen2.5-14B-Instruct, Qwen/Qwen2.5-32B, Qwen/Qwen2.5-32B-Instruct, Qwen/Qwen2.5-72B, Qwen/Qwen2.5-72B-Instruct |
-| [Qwen2.5-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Math-1.5B, Qwen/Qwen2.5-Math-1.5B-Instruct, Qwen/Qwen2.5-Math-7B, Qwen/Qwen2.5-Math-7B-Instruct, Qwen/Qwen2.5-Math-72B, Qwen/Qwen2.5-Math-72B-Instruct, Qwen/Qwen2.5-Math-RM-72B |
-| [Qwen2.5-Coder](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Coder-1.5B, Qwen/Qwen2.5-Coder-1.5B-Instruct, Qwen/Qwen2.5-Coder-7B, Qwen/Qwen2.5-Coder-7B-Instruct |
-| [Yuan2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/yuan/) | IEITYuan/Yuan2-2B, IEITYuan/Yuan2-51B, IEITYuan/Yuan2-102B |
+| Model Series | Model Name |
+|:-------------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| [PP-UIE](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/application/information_extraction) | paddlenlp/PP-UIE-0.5B, paddlenlp/PP-UIE-1.5B, paddlenlp/PP-UIE-7B, paddlenlp/PP-UIE-14B |
+| [LLaMA](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | facebook/llama-7b, facebook/llama-13b, facebook/llama-30b, facebook/llama-65b |
+| [Llama2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-2-7b, meta-llama/Llama-2-7b-chat, meta-llama/Llama-2-13b, meta-llama/Llama-2-13b-chat, meta-llama/Llama-2-70b, meta-llama/Llama-2-70b-chat |
+| [Llama3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Meta-Llama-3-8B, meta-llama/Meta-Llama-3-8B-Instruct, meta-llama/Meta-Llama-3-70B, meta-llama/Meta-Llama-3-70B-Instruct |
+| [Llama3.1](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Meta-Llama-3.1-8B, meta-llama/Meta-Llama-3.1-8B-Instruct, meta-llama/Meta-Llama-3.1-70B, meta-llama/Meta-Llama-3.1-70B-Instruct, meta-llama/Meta-Llama-3.1-405B, meta-llama/Meta-Llama-3.1-405B-Instruct, meta-llama/Llama-Guard-3-8B |
+| [Llama3.2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-3.2-1B, meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.2-3B, meta-llama/Llama-3.2-3B-Instruct, meta-llama/Llama-Guard-3-1B |
+| [Llama3.3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-3.3-70B-Instruct |
+| [Baichuan](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/baichuan) | baichuan-inc/Baichuan-7B, baichuan-inc/Baichuan-13B-Base, baichuan-inc/Baichuan-13B-Chat |
+| [Baichuan2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/baichuan) | baichuan-inc/Baichuan2-7B-Base, baichuan-inc/Baichuan2-7B-Chat, baichuan-inc/Baichuan2-13B-Base, baichuan-inc/Baichuan2-13B-Chat |
+| [Bloom](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/bloom) | bigscience/bloom-560m, bigscience/bloom-560m-bf16, bigscience/bloom-1b1, bigscience/bloom-3b, bigscience/bloom-7b1, bigscience/bloomz-560m, bigscience/bloomz-1b1, bigscience/bloomz-3b, bigscience/bloomz-7b1-mt, bigscience/bloomz-7b1-p3, bigscience/bloomz-7b1, bellegroup/belle-7b-2m |
+| [ChatGLM](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm/) | THUDM/chatglm-6b, THUDM/chatglm-6b-v1.1 |
+| [ChatGLM2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm2-6b |
+| [ChatGLM3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm3-6b |
+| [DeepSeekV2](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V2, deepseek-ai/DeepSeek-V2-Chat, deepseek-ai/DeepSeek-V2-Lite, deepseek-ai/DeepSeek-V2-Lite-Chat, deepseek-ai/DeepSeek-Coder-V2-Base, deepseek-ai/DeepSeek-Coder-V2-Instruct, deepseek-ai/DeepSeek-Coder-V2-Lite-Base, deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct |
+| [DeepSeekV3](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V3, deepseek-ai/DeepSeek-V3-Base |
+| [DeepSeek-R1](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-R1, deepseek-ai/DeepSeek-R1-Zero, deepseek-ai/DeepSeek-R1-Distill-Llama-70B, deepseek-ai/DeepSeek-R1-Distill-Llama-8B, deepseek-ai/DeepSeek-R1-Distill-Qwen-14B, deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, deepseek-ai/DeepSeek-R1-Distill-Qwen-32B, deepseek-ai/DeepSeek-R1-Distill-Qwen-7B |
+| [Gemma](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/gemma) | google/gemma-7b, google/gemma-7b-it, google/gemma-2b, google/gemma-2b-it |
+| [Mistral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mistral) | mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-7B-v0.1 |
+| [Mixtral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mixtral) | mistralai/Mixtral-8x7B-Instruct-v0.1 |
+| [OPT](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/opt) | facebook/opt-125m, facebook/opt-350m, facebook/opt-1.3b, facebook/opt-2.7b, facebook/opt-6.7b, facebook/opt-13b, facebook/opt-30b, facebook/opt-66b, facebook/opt-iml-1.3b, facebook/opt-iml-max-1.3b |
+| [Qwen](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | qwen/qwen-7b, qwen/qwen-7b-chat, qwen/qwen-14b, qwen/qwen-14b-chat, qwen/qwen-72b, qwen/qwen-72b-chat |
+| [Qwen1.5](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen1.5-0.5B, Qwen/Qwen1.5-0.5B-Chat, Qwen/Qwen1.5-1.8B, Qwen/Qwen1.5-1.8B-Chat, Qwen/Qwen1.5-4B, Qwen/Qwen1.5-4B-Chat, Qwen/Qwen1.5-7B, Qwen/Qwen1.5-7B-Chat, Qwen/Qwen1.5-14B, Qwen/Qwen1.5-14B-Chat, Qwen/Qwen1.5-32B, Qwen/Qwen1.5-32B-Chat, Qwen/Qwen1.5-72B, Qwen/Qwen1.5-72B-Chat, Qwen/Qwen1.5-110B, Qwen/Qwen1.5-110B-Chat, Qwen/Qwen1.5-MoE-A2.7B, Qwen/Qwen1.5-MoE-A2.7B-Chat |
+| [Qwen2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2-0.5B, Qwen/Qwen2-0.5B-Instruct, Qwen/Qwen2-1.5B, Qwen/Qwen2-1.5B-Instruct, Qwen/Qwen2-7B, Qwen/Qwen2-7B-Instruct, Qwen/Qwen2-72B, Qwen/Qwen2-72B-Instruct, Qwen/Qwen2-57B-A14B, Qwen/Qwen2-57B-A14B-Instruct |
+| [Qwen2-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2-Math-1.5B, Qwen/Qwen2-Math-1.5B-Instruct, Qwen/Qwen2-Math-7B, Qwen/Qwen2-Math-7B-Instruct, Qwen/Qwen2-Math-72B, Qwen/Qwen2-Math-72B-Instruct, Qwen/Qwen2-Math-RM-72B |
+| [Qwen2.5](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-0.5B, Qwen/Qwen2.5-0.5B-Instruct, Qwen/Qwen2.5-1.5B, Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-3B, Qwen/Qwen2.5-3B-Instruct, Qwen/Qwen2.5-7B, Qwen/Qwen2.5-7B-Instruct, Qwen/Qwen2.5-14B, Qwen/Qwen2.5-14B-Instruct, Qwen/Qwen2.5-32B, Qwen/Qwen2.5-32B-Instruct, Qwen/Qwen2.5-72B, Qwen/Qwen2.5-72B-Instruct |
+| [Qwen2.5-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Math-1.5B, Qwen/Qwen2.5-Math-1.5B-Instruct, Qwen/Qwen2.5-Math-7B, Qwen/Qwen2.5-Math-7B-Instruct, Qwen/Qwen2.5-Math-72B, Qwen/Qwen2.5-Math-72B-Instruct, Qwen/Qwen2.5-Math-RM-72B |
+| [Qwen2.5-Coder](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Coder-1.5B, Qwen/Qwen2.5-Coder-1.5B-Instruct, Qwen/Qwen2.5-Coder-7B, Qwen/Qwen2.5-Coder-7B-Instruct |
+| [Yuan2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/yuan/) | IEITYuan/Yuan2-2B, IEITYuan/Yuan2-51B, IEITYuan/Yuan2-102B |
* 4D parallelism and operator optimizations are available for the LLaMA, Baichuan, Bloom, ChatGLM, Gemma, Mistral, OPT, and Qwen series; the 【LLM】4D parallelism and operator support matrix is as follows:
@@ -130,19 +141,19 @@
| Model | Pretrain | SFT | LoRA | FlashMask | Prefix Tuning | DPO/SimPO/ORPO/KTO | RLHF | Mergekit | Quantization |
-|--------------------------------------------|:--------:|:---:|:----:|:---------:|:-------------:|:--------------:|:----:|:-----:|:------------:|
-| [Llama](./llm/config/llama) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [Qwen](./llm/config/qwen) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | 🚧 |
-| [Mixtral](./llm/config/mixtral) | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | 🚧 | ✅ | 🚧 |
-| [Mistral](./llm/config/mistral) | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | ✅ | 🚧 |
-| [Baichuan/Baichuan2](./llm/config/llama) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ |
-| [ChatGLM-6B](./llm/config/chatglm) | ✅ | ✅ | ✅ | 🚧 | ✅ | 🚧 | 🚧 | ✅ | ✅ |
-| [ChatGLM2/ChatGLM3](./llm/config/chatglm2) | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | ✅ | ✅ |
-| [Bloom](./llm/config/bloom) | ✅ | ✅ | ✅ | 🚧 | ✅ | 🚧 | 🚧 | ✅ | ✅ |
-| [GPT-3](./llm/config/gpt-3) | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | ✅ | 🚧 |
-| [OPT](./llm/config/opt) | ✅ | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ | 🚧 |
-| [Gemma](./llm/config/gemma) | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | 🚧 | ✅ | 🚧 |
-| [Yuan](./llm/config/yuan) | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | 🚧 | ✅ | 🚧 |
+|--------------------------------------------|:--------:|:---:|:----:|:---------:|:-------------:|:------------------:|:----:|:--------:|:------------:|
+| [Llama](./llm/config/llama) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [Qwen](./llm/config/qwen) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | 🚧 |
+| [Mixtral](./llm/config/mixtral) | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | 🚧 | ✅ | 🚧 |
+| [Mistral](./llm/config/mistral) | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | ✅ | 🚧 |
+| [Baichuan/Baichuan2](./llm/config/llama) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ |
+| [ChatGLM-6B](./llm/config/chatglm) | ✅ | ✅ | ✅ | 🚧 | ✅ | 🚧 | 🚧 | ✅ | ✅ |
+| [ChatGLM2/ChatGLM3](./llm/config/chatglm2) | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | ✅ | ✅ |
+| [Bloom](./llm/config/bloom) | ✅ | ✅ | ✅ | 🚧 | ✅ | 🚧 | 🚧 | ✅ | ✅ |
+| [GPT-3](./llm/config/gpt-3) | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | ✅ | 🚧 |
+| [OPT](./llm/config/opt) | ✅ | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ | 🚧 |
+| [Gemma](./llm/config/gemma) | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | 🚧 | ✅ | 🚧 |
+| [Yuan](./llm/config/yuan) | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | 🚧 | ✅ | 🚧 |
* [LLM inference](./llm/docs/predict/inference.md) supports the LLaMA, Qwen, Mistral, ChatGLM, Bloom, and Baichuan series, with weight-only INT8 and INT4 inference, as well as INT8 and FP8 quantized inference over WAC (weights, activations, KV cache); the 【LLM】inference support matrix is as follows:
| Model / Quantization Support | FP16/BF16 | WINT8 | WINT4 | INT8-A8W8 | FP8-A8W8 | INT8-A8W8C8 |
@@ -160,7 +171,7 @@
### Requirements
* python >= 3.8
-* paddlepaddle >= 3.0.0b0
+* paddlepaddle >= 3.0.0rc0
If you have not installed PaddlePaddle yet, please refer to the [PaddlePaddle website](https://www.paddlepaddle.org.cn/) for installation instructions.
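For a quick CPU-only setup, an install along these lines is usually enough (a sketch; for GPU builds, pick the wheel matching your CUDA version from the installation page above):

```shell
# CPU wheel sketch; the version specifier is illustrative -- any release satisfying
# paddlepaddle >= 3.0.0rc0 works. For GPU, follow the official installation page.
python -m pip install "paddlepaddle>=3.0.0rc0"
```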
@@ -205,7 +216,7 @@ wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwe
wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.idx
cd .. # change folder to PaddleNLP/llm
# To use use_fused_rms_norm=true, first install fused_ln from slm/model_zoo/gpt-3/external_ops
-python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_pretrain.py ./config/llama/pretrain_argument.json --use_fused_rms_norm false
+python -u run_pretrain.py ./config/qwen/pretrain_argument_0p5b.json
```
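For multi-GPU pre-training, the same script can still be launched through `paddle.distributed.launch` (a sketch; adjust the GPU list and the config file to your machine):

```shell
# Multi-GPU variant of the command above (sketch; GPU ids and config path are illustrative)
python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_pretrain.py ./config/qwen/pretrain_argument_0p5b.json
```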
### LLM SFT Fine-Tuning
@@ -215,7 +226,7 @@ git clone https://github.com/PaddlePaddle/PaddleNLP.git && cd PaddleNLP # 如已
mkdir -p llm/data && cd llm/data
wget https://bj.bcebos.com/paddlenlp/datasets/examples/AdvertiseGen.tar.gz && tar -zxvf AdvertiseGen.tar.gz
cd .. # change folder to PaddleNLP/llm
-python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_finetune.py ./config/llama/sft_argument.json
+python -u run_finetune.py ./config/qwen/sft_argument_0p5b.json
```
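Multi-GPU SFT uses the same launcher (a sketch; adjust the GPU list and the config file as needed):

```shell
# Multi-GPU variant of the command above (sketch; GPU ids and config path are illustrative)
python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_finetune.py ./config/qwen/sft_argument_0p5b.json
```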
For more steps in the end-to-end LLM workflow, please refer to the [PaddlePaddle LLM toolkit](./llm) documentation.
@@ -230,7 +241,7 @@ dataset = load_dataset("ZHUI/alpaca_demo", split="train")
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT", device="gpu")
trainer = SFTTrainer(
args=training_args,
- model="Qwen/Qwen2.5-0.5B",
+ model="Qwen/Qwen2.5-0.5B-Instruct",
train_dataset=dataset,
)
trainer.train()
diff --git a/README_en.md b/README_en.md
index db0edf1ceeaa..d7934748379d 100644
--- a/README_en.md
+++ b/README_en.md
@@ -7,7 +7,7 @@
------------------------------------------------------------------------------------------
-
+
@@ -16,6 +16,7 @@
+
@@ -52,6 +53,9 @@ The fine-tuning algorithms are deeply integrated with zero-padding data streams
The high-performance inference module of the large model toolkit incorporates dynamic insertion and operator fusion strategies throughout the entire process, greatly accelerating parallel inference speed. The underlying implementation details are encapsulated, enabling out-of-the-box high-performance parallel inference capabilities.
+## Documentation
+For detailed documentation, visit the [PaddleNLP Documentation](https://paddlenlp.readthedocs.io/).
+
------------------------------------------------------------------------------------------
## Support Models
@@ -68,7 +72,7 @@ Detailed list 👉 [Supported Model List](https://github.com/PaddlePaddle/Paddle
### Pip Installation
```shell
-pip install --upgrade paddlenlp==3.0.0b2
+pip install --upgrade paddlenlp==3.0.0b3
```
or you can install the latest develop branch code with the following command:
diff --git a/csrc/README.md b/csrc/README.md
index 02bd4a372e46..24fe14da6756 100644
--- a/csrc/README.md
+++ b/csrc/README.md
@@ -1,6 +1,9 @@
-# PaddleNLP Custom OPs
+# PaddleNLP High-Performance Custom Inference Operators for LLMs
-This document describes how to compile and install the PaddleNLP custom OPs.
+This document describes how to compile and install PaddleNLP's high-performance custom inference operators for LLMs.
+
+Using these high-performance operators can significantly speed up LLM inference.
+See the LLM inference tutorial [here](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/README.md#6-%E6%8E%A8%E7%90%86).
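Once the C++ dependencies described below are installed, the operators are typically compiled and installed from this directory with a command along these lines (a sketch; the setup script name is an assumption -- check the files in this directory before running):

```shell
# Sketch only -- verify the setup script name in csrc/ before running.
cd csrc
python setup_cuda.py install
```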
## Install C++ Dependencies
diff --git a/csrc/gpu/append_attention.cu b/csrc/gpu/append_attention.cu
index d24a20e48d11..d6f3efbbf3df 100644
--- a/csrc/gpu/append_attention.cu
+++ b/csrc/gpu/append_attention.cu
@@ -56,6 +56,7 @@ std::vector AppendAttentionKernel(
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const int max_input_length,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
@@ -97,13 +98,13 @@ std::vector AppendAttentionKernel(
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
- {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+ {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
paddle::DataType::INT8,
qkv.place());
}
else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
- {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+ {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
paddle::DataType::FLOAT8_E4M3FN,
qkv.place());
}else{
@@ -111,7 +112,7 @@ std::vector AppendAttentionKernel(
}
} else {
fmha_out = GetEmptyTensor(
- {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+ {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
D,
qkv.place());
}
@@ -203,6 +204,7 @@ std::vector AppendAttentionKernel(
encoder_block_shape_q,
max_input_length,
max_enc_len_this_time_data,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -240,6 +242,7 @@ std::vector AppendAttentionKernel(
encoder_block_shape_q,
max_input_length,
max_enc_len_this_time_data,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -282,6 +285,7 @@ std::vector AppendAttentionKernel(
encoder_block_shape_q,
max_input_length,
max_enc_len_this_time_data,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -428,6 +432,7 @@ std::vector AppendAttentionKernel(
decoder_block_shape_q,
max_input_length,
max_len_kv_data,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -465,6 +470,7 @@ std::vector AppendAttentionKernel(
decoder_block_shape_q,
max_input_length,
max_len_kv_data,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -508,6 +514,7 @@ std::vector AppendAttentionKernel(
decoder_block_shape_q,
max_input_length,
max_len_kv_data,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -565,6 +572,7 @@ std::vector AppendAttention(
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const int max_input_length,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
@@ -578,9 +586,10 @@ std::vector AppendAttention(
meta_data.token_nums = qkv_dims[0];
meta_data.kv_num_heads = key_cache_dims[1];
meta_data.head_dims = key_cache_dims[3];
- const int total_num_head =
- qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims;
- meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads;
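+ // Packed qkv layout per token: [q_num_heads * head_dims (Q) | kv_num_heads * head_dims (K) | kv_num_heads * head_dims_v (V)]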
+ meta_data.head_dims_v = value_cache.dims()[3];
+ const int q_hidden_size =
+ qkv_dims[qkv_dims.size() - 1] - meta_data.kv_num_heads * (meta_data.head_dims + meta_data.head_dims_v);
+ meta_data.q_num_heads = q_hidden_size / meta_data.head_dims;
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = key_cache.dims()[2];
@@ -626,6 +635,7 @@ std::vector AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -672,6 +682,7 @@ std::vector AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -719,6 +730,7 @@ std::vector AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -764,6 +776,7 @@ std::vector AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -821,10 +834,12 @@ std::vector> AppendAttentionInferShape(
const paddle::optional>& out_linear_smooths_shape) {
const int token_num = qkv_shape[0];
const int kv_num_heads = key_cache_shape[1];
- const int head_dim = key_cache_shape[3];
- const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
- const int num_heads = total_num_head - 2 * kv_num_heads;
- return {{token_num, num_heads * head_dim}, qkv_shape};
+ const int head_dim_qk = key_cache_shape[3];
+ const int head_dim_v = value_cache_shape[3];
+ const int q_hidden_size =
+ qkv_shape[qkv_shape.size() - 1] - kv_num_heads * (head_dim_qk + head_dim_v);
+ const int num_heads = q_hidden_size / head_dim_qk;
+ return {{token_num, num_heads * head_dim_v}, qkv_shape};
}
std::vector AppendAttentionInferDtype(
@@ -865,6 +880,7 @@ std::vector AppendAttentionInferDtype(
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const int max_input_length,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
@@ -941,6 +957,7 @@ PD_BUILD_OP(append_attention)
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"max_input_length: int",
+ "softmax_scale: float",
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",
diff --git a/csrc/gpu/append_attn/append_attention_c16_impl.cuh b/csrc/gpu/append_attn/append_attention_c16_impl.cuh
index 3b08d0a85dbc..8b75fa13cdca 100644
--- a/csrc/gpu/append_attn/append_attention_c16_impl.cuh
+++ b/csrc/gpu/append_attn/append_attention_c16_impl.cuh
@@ -23,15 +23,17 @@ template
__global__ void multi_query_append_attention_kernel(
- T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
+ T *__restrict__ q, // [token_num, num_heads * head_dim_qk + kv_num_heads * (head_dim_qk + head_dim_v)]
T *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
T *__restrict__ cache_v,
@@ -46,7 +48,7 @@ __global__ void multi_query_append_attention_kernel(
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
- const float scale,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -57,7 +59,9 @@ __global__ void multi_query_append_attention_kernel(
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5) {
- constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b();
+ constexpr uint32_t num_vecs_per_head_qk =
+ HEAD_DIM_QK / num_elems_per_128b();
+ constexpr uint32_t num_vecs_per_head_v = HEAD_DIM_V / num_elems_per_128b();
const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
const uint32_t kv_num_heads = gridDim.z;
const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE;
@@ -104,25 +108,30 @@ __global__ void multi_query_append_attention_kernel(
extern __shared__ uint8_t smem[];
float s_frag[num_frags_x][num_frags_z][8];
- float o_frag[num_frags_x][num_frags_y][8];
+ float o_frag[num_frags_x][num_frags_y_v][8];
float m_frag[num_frags_x][2];
float d_frag[num_frags_x][2];
- init_states(o_frag, m_frag, d_frag);
-
- const uint32_t q_n_stride = q_num_heads * HEAD_DIM;
- const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM;
- const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
- const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
- const uint32_t kv_b_stride = HEAD_DIM;
+ init_states(o_frag, m_frag, d_frag);
+
+ const uint32_t q_n_stride = q_num_heads * HEAD_DIM_V;
+ const uint32_t q_ori_n_stride = q_num_heads * HEAD_DIM_QK +
+ kv_num_heads * HEAD_DIM_QK +
+ kv_num_heads * HEAD_DIM_V;
+ const uint32_t k_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM_QK;
+ const uint32_t k_h_stride = BLOCK_SIZE * HEAD_DIM_QK;
+ const uint32_t k_b_stride = HEAD_DIM_QK;
+ const uint32_t v_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM_V;
+ const uint32_t v_h_stride = BLOCK_SIZE * HEAD_DIM_V;
+ const uint32_t v_b_stride = HEAD_DIM_V;
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_base_seq_id_this_block =
(tile_id * NUM_WARPS + wid) * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
- q_head_idx * HEAD_DIM +
+ q_head_idx * HEAD_DIM_QK +
tid % 8 * num_elems_per_128b();
const uint32_t o_offset = q_start_seq_id * q_n_stride +
- q_head_idx * HEAD_DIM +
+ q_head_idx * HEAD_DIM_V +
tid % 8 * num_elems_per_128b();
T *q_base_ptr = q + q_offset;
T *o_base_ptr_T = nullptr;
@@ -130,13 +139,13 @@ __global__ void multi_query_append_attention_kernel(
if constexpr (partition_kv) {
if (ENABLE_PREFILL) {
o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
- chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+ chunk_idx * q_n_stride + q_head_idx * HEAD_DIM_V +
tid % 8 * num_elems_per_128b();
} else {
o_base_ptr_T =
tmp_workspace +
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
- chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+ chunk_idx * q_n_stride + q_head_idx * HEAD_DIM_V +
tid % 8 * num_elems_per_128b();
}
} else {
@@ -144,24 +153,42 @@ __global__ void multi_query_append_attention_kernel(
}
smem_t qo_smem(smem);
- uint32_t q_smem_offset_r = smem_t::get_permuted_offset(
+ uint32_t q_smem_offset_r = smem_t::get_permuted_offset(
wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16
- load_q_global_smem(
+ load_q_global_smem(
q_base_ptr,
&qo_smem,
q_base_seq_id_this_block,
q_end,
q_ori_n_stride,
- HEAD_DIM);
+ HEAD_DIM_QK);
commit_group();
wait_group<0>();
__syncthreads();
+#ifdef DEBUG_PERCISION
+ if (tid == 0 && threadIdx.y == 0 && blockIdx.z == 0 && blockIdx.x == 0) {
+ printf("q_smem(%d * 192个bfloat16):\n", 4 * num_frags_x * 16);
+ // const uint32_t k_num = num_frags_z * 64 * HEAD_DIM / 2 * sizeof(CacheT);
+ T *q_smem_t = reinterpret_cast(qo_smem.base);
+ for (uint32_t i = 0; i < 4 * num_frags_x * 16; ++i) {
+ printf("q_smem[%d]:", (int)i);
+ for (uint32_t j = 0; j < HEAD_DIM_QK / 8; ++j) {
+ printf("[");
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.2f ", (float)q_smem_t[i * HEAD_DIM_QK + j * 8 + k]);
+ }
+ printf("]");
+ }
+ printf("\n");
+ }
+ }
+ __syncthreads();
+#endif
+ q_smem_inplace_multiply_sm_scale(
+ &qo_smem, softmax_scale);
- q_smem_inplace_multiply_sm_scale(&qo_smem,
- scale);
-
- smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
- v_smem(smem + (NUM_WARPS * num_frags_x + num_frags_z) * 16 * HEAD_DIM *
+ smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM_QK * sizeof(T)),
+ v_smem(smem + (NUM_WARPS * num_frags_x + num_frags_z) * 16 * HEAD_DIM_QK *
sizeof(T));
@@ -182,50 +209,55 @@ __global__ void multi_query_append_attention_kernel(
chunk_start)))
: chunk_len) /
(num_frags_z * 16);
- uint32_t k_smem_offset_r = smem_t::get_permuted_offset(
+ uint32_t k_smem_offset_r = smem_t::get_permuted_offset(
8 * (tid / 16) + tid % 8, (tid % 16) / 8);
uint32_t v_smem_offset_r =
- smem_t::get_permuted_offset(tid % 16, tid / 16);
+ smem_t::get_permuted_offset(tid % 16, tid / 16);
- uint32_t kv_smem_offset_w = smem_t::get_permuted_offset(
+ uint32_t k_smem_offset_w = smem_t::get_permuted_offset(
+ wid * 4 + tid / 8, tid % 8);
+ uint32_t v_smem_offset_w = smem_t::get_permuted_offset(
wid * 4 + tid / 8, tid % 8);
uint32_t kv_idx_base = chunk_start;
int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]);
- const uint32_t const_offset = kv_head_idx * kv_h_stride +
- (wid * 4 + tid / 8) * kv_b_stride +
- tid % 8 * num_elems_per_128b();
- T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset;
- T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset;
+ const uint32_t const_offset_k = kv_head_idx * k_h_stride +
+ (wid * 4 + tid / 8) * k_b_stride +
+ tid % 8 * num_elems_per_128b();
+ const uint32_t const_offset_v = kv_head_idx * v_h_stride +
+ (wid * 4 + tid / 8) * v_b_stride +
+ tid % 8 * num_elems_per_128b();
+ T *cache_k_now = cache_k + block_id * k_n_stride + const_offset_k;
+ T *cache_v_now = cache_v + block_id * v_n_stride + const_offset_v;
produce_kv_blockwise(k_smem,
- &kv_smem_offset_w,
+ &k_smem_offset_w,
&cache_k_now,
kv_head_idx,
- kv_n_stride,
- kv_h_stride,
- kv_b_stride,
+ k_n_stride,
+ k_h_stride,
+ k_b_stride,
kv_idx_base,
chunk_end);
commit_group();
produce_kv_blockwise(v_smem,
- &kv_smem_offset_w,
+ &v_smem_offset_w,
&cache_v_now,
kv_head_idx,
- kv_n_stride,
- kv_h_stride,
- kv_b_stride,
+ v_n_stride,
+ v_h_stride,
+ v_b_stride,
kv_idx_base,
chunk_end);
commit_group();
@@ -233,10 +265,45 @@ __global__ void multi_query_append_attention_kernel(
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
wait_group<1>();
__syncthreads();
-
+#ifdef DEBUG_PERCISION
+ if (tid == 0 && threadIdx.y == 0 && blockIdx.z == 0 && blockIdx.x == 0) {
+ printf("k_smem(%d * 192个bfloat16):\n", num_frags_z * 16);
+ // const uint32_t k_num = num_frags_z * 64 * HEAD_DIM / 2 *
+ // sizeof(CacheT);
+ T *k_smem_t = reinterpret_cast(k_smem.base);
+ for (uint32_t i = 0; i < num_frags_z * 16; ++i) {
+ printf("k_smem[%d]:", (int)i);
+ for (uint32_t j = 0; j < HEAD_DIM_QK / 8; ++j) {
+ printf("[");
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.2f ", (float)k_smem_t[i * HEAD_DIM_QK + j * 8 + k]);
+ }
+ printf("]");
+ }
+ printf("\n");
+ }
+ }
+ __syncthreads();
+#endif
// s = qk
- compute_qk(
+ compute_qk(
&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
+#ifdef DEBUG_PERCISION
+ __syncthreads();
+ if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.z == 0 &&
+ blockIdx.x == 0) {
+ for (uint32_t i = 0; i < num_frags_x; ++i) {
+ for (uint32_t j = 0; j < num_frags_z; ++j) {
+ printf("s_frag[%d][%d]:\n", i, j);
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.4f ", s_frag[i][j][k]);
+ }
+ printf("\n");
+ }
+ }
+ }
+ __syncthreads();
+#endif
// mask according to kv_idx and q_idx
if (iter >= mask_check_iteration) {
mask_s(q_base_seq_id_this_block,
kv_idx_base,
q_len,
@@ -255,7 +322,7 @@ __global__ void multi_query_append_attention_kernel(
}
// update m,d
- update_mdo_states(
+ update_mdo_states(
s_frag, o_frag, m_frag, d_frag);
__syncthreads();
@@ -264,43 +331,77 @@ __global__ void multi_query_append_attention_kernel(
if (block_id < 0) {
block_id = 0;
}
- cache_k_now = cache_k + block_id * kv_n_stride + const_offset;
+ cache_k_now = cache_k + block_id * k_n_stride + const_offset_k;
produce_kv_blockwise(k_smem,
- &kv_smem_offset_w,
+ &k_smem_offset_w,
&cache_k_now,
kv_head_idx,
- kv_n_stride,
- kv_h_stride,
- kv_b_stride,
+ k_n_stride,
+ k_h_stride,
+ k_b_stride,
kv_idx_base,
chunk_end);
commit_group();
wait_group<1>();
__syncthreads();
-
+#ifdef DEBUG_PERCISION
+ if (tid == 0 && threadIdx.y == 0 && blockIdx.z == 0 && blockIdx.x == 0) {
+ printf("v_smem(%d * 128个bfloat16):\n", num_frags_z * 16);
+ // const uint32_t k_num = num_frags_z * 64 * HEAD_DIM / 2 *
+ // sizeof(CacheT);
+ T *v_smem_t = reinterpret_cast(v_smem.base);
+ for (uint32_t i = 0; i < num_frags_z * 16; ++i) {
+ printf("v_smem[%d]:", (int)i);
+ for (uint32_t j = 0; j < HEAD_DIM_V / 8; ++j) {
+ printf("[");
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.2f ", (float)v_smem_t[i * HEAD_DIM_V + j * 8 + k]);
+ }
+ printf("]");
+ }
+ printf("\n");
+ }
+ }
+ __syncthreads();
+#endif
// compute sfm*v
- compute_sfm_v(
+ compute_sfm_v(
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag);
-
+#ifdef DEBUG_PERCISION
+ __syncthreads();
+ if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.z == 0 &&
+ blockIdx.x == 0) {
+ for (uint32_t i = 0; i < num_frags_x; ++i) {
+ for (uint32_t j = 0; j < num_frags_y_v; ++j) {
+ printf("o_frag[%d][%d]:\n", i, j);
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.4f ", s_frag[i][j][k]);
+ }
+ printf("\n");
+ }
+ }
+ }
__syncthreads();
- cache_v_now = cache_v + block_id * kv_n_stride + const_offset;
+#endif
+ __syncthreads();
+ cache_v_now = cache_v + block_id * v_n_stride + const_offset_v;
produce_kv_blockwise(v_smem,
- &kv_smem_offset_w,
+ &v_smem_offset_w,
&cache_v_now,
kv_head_idx,
- kv_n_stride,
- kv_h_stride,
- kv_b_stride,
+ v_n_stride,
+ v_h_stride,
+ v_b_stride,
kv_idx_base,
chunk_end);
commit_group();
@@ -309,12 +410,28 @@ __global__ void multi_query_append_attention_kernel(
__syncthreads();
if constexpr (!partition_kv) {
- normalize_d(o_frag, d_frag);
+ normalize_d(o_frag, d_frag);
+ }
+#ifdef DEBUG_PERCISION
+ __syncthreads();
+ if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.z == 0 &&
+ blockIdx.x == 0) {
+ for (uint32_t i = 0; i < num_frags_x; ++i) {
+ for (uint32_t j = 0; j < num_frags_y_v; ++j) {
+ printf("o_frag[%d][%d]:\n", i, j);
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.4f ", s_frag[i][j][k]);
+ }
+ printf("\n");
+ }
+ }
}
+ __syncthreads();
+#endif
if constexpr (partition_kv) {
write_o_reg_gmem_shift_smooth_quant(
o_frag,
&qo_smem,
@@ -328,11 +445,11 @@ __global__ void multi_query_append_attention_kernel(
in_scale,
q_len,
partition_kv ? q_n_stride * num_chunks : q_n_stride,
- HEAD_DIM);
+ HEAD_DIM_V);
} else {
write_o_reg_gmem_shift_smooth_quant(
o_frag,
&qo_smem,
@@ -346,7 +463,7 @@ __global__ void multi_query_append_attention_kernel(
in_scale,
q_len,
partition_kv ? q_n_stride * num_chunks : q_n_stride,
- HEAD_DIM);
+ HEAD_DIM_V);
}
@@ -387,15 +504,17 @@ template
__global__ void multi_query_append_attention_warp1_4_kernel(
- T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
+ T *__restrict__ q, // [token_num, num_heads * head_dim_qk + kv_num_heads * (head_dim_qk + head_dim_v)]
T *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
T *__restrict__ cache_v,
@@ -410,7 +529,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
- const float scale,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -421,7 +540,9 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5) {
- constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b();
+ constexpr uint32_t num_vecs_per_head_qk =
+ HEAD_DIM_QK / num_elems_per_128b();
+ constexpr uint32_t num_vecs_per_head_v = HEAD_DIM_V / num_elems_per_128b();
static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
@@ -467,24 +588,29 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
extern __shared__ uint8_t smem[];
float s_frag[num_frags_x][num_frags_z][8];
- float o_frag[num_frags_x][num_frags_y][8];
+ float o_frag[num_frags_x][num_frags_y_v][8];
float m_frag[num_frags_x][2];
float d_frag[num_frags_x][2];
- init_states(o_frag, m_frag, d_frag);
-
- const uint32_t q_n_stride = q_num_heads * HEAD_DIM;
- const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM;
- const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
- const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
- const uint32_t kv_b_stride = HEAD_DIM;
+ init_states(o_frag, m_frag, d_frag);
+
+ const uint32_t q_n_stride = q_num_heads * HEAD_DIM_V;
+ const uint32_t q_ori_n_stride = q_num_heads * HEAD_DIM_QK +
+ kv_num_heads * HEAD_DIM_QK +
+ kv_num_heads * HEAD_DIM_V;
+ const uint32_t k_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM_QK;
+ const uint32_t k_h_stride = BLOCK_SIZE * HEAD_DIM_QK;
+ const uint32_t k_b_stride = HEAD_DIM_QK;
+ const uint32_t v_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM_V;
+ const uint32_t v_h_stride = BLOCK_SIZE * HEAD_DIM_V;
+ const uint32_t v_b_stride = HEAD_DIM_V;
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
- q_head_idx * HEAD_DIM +
+ q_head_idx * HEAD_DIM_QK +
tid % 8 * num_elems_per_128b();
const uint32_t o_offset = q_start_seq_id * q_n_stride +
- q_head_idx * HEAD_DIM +
+ q_head_idx * HEAD_DIM_V +
tid % 8 * num_elems_per_128b();
T *q_base_ptr = q + q_offset;
T *o_base_ptr_T = nullptr;
@@ -494,41 +620,59 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
} else {
if (ENABLE_PREFILL) {
o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride +
- chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+ chunk_idx * q_n_stride + q_head_idx * HEAD_DIM_V +
tid % 8 * num_elems_per_128b();
} else {
o_base_ptr_T =
tmp_workspace +
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
- chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+ chunk_idx * q_n_stride + q_head_idx * HEAD_DIM_V +
tid % 8 * num_elems_per_128b();
}
}
smem_t qo_smem(smem);
- uint32_t q_smem_offset_r = smem_t::get_permuted_offset(
+ uint32_t q_smem_offset_r = smem_t::get_permuted_offset(
tid % 16, tid / 16); // 16 * 16
load_q_global_smem_multi_warps(q_base_ptr,
&qo_smem,
q_base_seq_id_this_block,
q_end,
q_ori_n_stride,
- HEAD_DIM);
+ HEAD_DIM_QK);
commit_group();
wait_group<0>();
__syncthreads();
+#ifdef DEBUG_PERCISION_DEC
+ if (tid == 0 && threadIdx.y == 0 && blockIdx.z == 0 && blockIdx.x == 0) {
+ printf("q_smem(%d * 192个bfloat16):\n", num_frags_x * 16);
+ // const uint32_t k_num = num_frags_z * 64 * HEAD_DIM / 2 * sizeof(CacheT);
+ T *q_smem_t = reinterpret_cast(qo_smem.base);
+ for (uint32_t i = 0; i < 4 * num_frags_x * 16; ++i) {
+ printf("q_smem[%d]:", (int)i);
+ for (uint32_t j = 0; j < HEAD_DIM_QK / 8; ++j) {
+ printf("[");
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.2f ", (float)q_smem_t[i * HEAD_DIM_QK + j * 8 + k]);
+ }
+ printf("]");
+ }
+ printf("\n");
+ }
+ }
+ __syncthreads();
+#endif
+ q_smem_inplace_multiply_sm_scale_multi_warps(
+ &qo_smem, softmax_scale);
- q_smem_inplace_multiply_sm_scale_multi_warps(
- &qo_smem, scale);
-
- smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
- v_smem(smem + (num_frags_x + NUM_WARP_KV * num_frags_z) * 16 * HEAD_DIM *
- sizeof(T));
+ smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM_QK * sizeof(T)),
+ v_smem(smem + (num_frags_x + NUM_WARP_KV * num_frags_z) * 16 *
+ HEAD_DIM_QK * sizeof(T));
const uint32_t num_iterations = div_up(
CAUSAL
@@ -548,34 +692,39 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
: chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
- uint32_t k_smem_offset_r = smem_t::get_permuted_offset(
+ uint32_t k_smem_offset_r = smem_t::get_permuted_offset(
wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
- uint32_t v_smem_offset_r = smem_t::get_permuted_offset(
+ uint32_t v_smem_offset_r = smem_t::get_permuted_offset(
wid * num_frags_z * 16 + tid % 16, tid / 16);
- uint32_t kv_smem_offset_w = smem_t::get_permuted_offset(
+ uint32_t k_smem_offset_w = smem_t::get_permuted_offset(
+ wid * 4 + tid / 8, tid % 8);
+ uint32_t v_smem_offset_w = smem_t::get_permuted_offset(
wid * 4 + tid / 8, tid % 8);
uint32_t kv_idx_base = chunk_start;
int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]);
- const uint32_t const_offset = kv_head_idx * kv_h_stride +
- (wid * 4 + tid / 8) * kv_b_stride +
- tid % 8 * num_elems_per_128b();
- T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset;
- T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset;
+ const uint32_t const_offset_k = kv_head_idx * k_h_stride +
+ (wid * 4 + tid / 8) * k_b_stride +
+ tid % 8 * num_elems_per_128b();
+ const uint32_t const_offset_v = kv_head_idx * v_h_stride +
+ (wid * 4 + tid / 8) * v_b_stride +
+ tid % 8 * num_elems_per_128b();
+ T *cache_k_now = cache_k + block_id * k_n_stride + const_offset_k;
+ T *cache_v_now = cache_v + block_id * v_n_stride + const_offset_v;
produce_kv_blockwise(k_smem,
- &kv_smem_offset_w,
+ &k_smem_offset_w,
&cache_k_now,
kv_head_idx,
- kv_n_stride,
- kv_h_stride,
- kv_b_stride,
+ k_n_stride,
+ k_h_stride,
+ k_b_stride,
kv_idx_base,
chunk_end);
commit_group();
@@ -583,15 +732,15 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
produce_kv_blockwise(v_smem,
- &kv_smem_offset_w,
+ &v_smem_offset_w,
&cache_v_now,
kv_head_idx,
- kv_n_stride,
- kv_h_stride,
- kv_b_stride,
+ v_n_stride,
+ v_h_stride,
+ v_b_stride,
kv_idx_base,
chunk_end);
commit_group();
@@ -600,10 +749,45 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
wait_group<1>();
__syncthreads();
-
+#ifdef DEBUG_PERCISION_DEC
+ if (tid == 0 && threadIdx.y == 0 && blockIdx.z == 0 && blockIdx.x == 0) {
+ printf("k_smem(%d * 192个bfloat16):\n", 4 * num_frags_z * 16);
+ // const uint32_t k_num = num_frags_z * 64 * HEAD_DIM / 2 *
+ // sizeof(CacheT);
+ T *k_smem_t = reinterpret_cast(k_smem.base);
+ for (uint32_t i = 0; i < num_frags_z * 16; ++i) {
+ printf("k_smem[%d]:", (int)i);
+ for (uint32_t j = 0; j < HEAD_DIM_QK / 8; ++j) {
+ printf("[");
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.2f ", (float)k_smem_t[i * HEAD_DIM_QK + j * 8 + k]);
+ }
+ printf("]");
+ }
+ printf("\n");
+ }
+ }
+ __syncthreads();
+#endif
// s = qk
- compute_qk(
+ compute_qk(
&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
+#ifdef DEBUG_PERCISION_DEC
+ __syncthreads();
+ if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.z == 0 &&
+ blockIdx.x == 0) {
+ for (uint32_t i = 0; i < num_frags_x; ++i) {
+ for (uint32_t j = 0; j < num_frags_z; ++j) {
+ printf("s_frag[%d][%d]:\n", i, j);
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.4f ", s_frag[i][j][k]);
+ }
+ printf("\n");
+ }
+ }
+ }
+ __syncthreads();
+#endif
// mask according to kv_idx and q_idx
if (iter >= mask_check_iteration) {
mask_s(q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
@@ -622,7 +806,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
}
// update m,d
- update_mdo_states(
+ update_mdo_states(
s_frag, o_frag, m_frag, d_frag);
__syncthreads();
@@ -631,43 +815,77 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
if (block_id < 0) {
block_id = 0;
}
- cache_k_now = cache_k + block_id * kv_n_stride + const_offset;
+ cache_k_now = cache_k + block_id * k_n_stride + const_offset_k;
produce_kv_blockwise(k_smem,
- &kv_smem_offset_w,
+ &k_smem_offset_w,
&cache_k_now,
kv_head_idx,
- kv_n_stride,
- kv_h_stride,
- kv_b_stride,
+ k_n_stride,
+ k_h_stride,
+ k_b_stride,
kv_idx_base,
chunk_end);
commit_group();
wait_group<1>();
__syncthreads();
-
+#ifdef DEBUG_PERCISION_DEC
+ if (tid == 0 && threadIdx.y == 0 && blockIdx.z == 0 && blockIdx.x == 0) {
+ printf("v_smem(%d * 128个bfloat16):\n", 4 * num_frags_z * 16);
+ // const uint32_t k_num = num_frags_z * 64 * HEAD_DIM / 2 *
+ // sizeof(CacheT);
+ T *v_smem_t = reinterpret_cast(v_smem.base);
+ for (uint32_t i = 0; i < num_frags_z * 16; ++i) {
+ printf("v_smem[%d]:", (int)i);
+ for (uint32_t j = 0; j < HEAD_DIM_V / 8; ++j) {
+ printf("[");
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.2f ", (float)v_smem_t[i * HEAD_DIM_V + j * 8 + k]);
+ }
+ printf("]");
+ }
+ printf("\n");
+ }
+ }
+ __syncthreads();
+#endif
// compute sfm*v
- compute_sfm_v(
+ compute_sfm_v(
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag);
__syncthreads();
-
- cache_v_now = cache_v + block_id * kv_n_stride + const_offset;
+#ifdef DEBUG_PERCISION_DEC
+ __syncthreads();
+ if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.z == 0 &&
+ blockIdx.x == 0) {
+ for (uint32_t i = 0; i < num_frags_x; ++i) {
+ for (uint32_t j = 0; j < num_frags_y_v; ++j) {
+ printf("o_frag[%d][%d]:\n", i, j);
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.4f ", s_frag[i][j][k]);
+ }
+ printf("\n");
+ }
+ }
+ }
+ __syncthreads();
+#endif
+ cache_v_now = cache_v + block_id * v_n_stride + const_offset_v;
produce_kv_blockwise(v_smem,
- &kv_smem_offset_w,
+ &v_smem_offset_w,
&cache_v_now,
kv_head_idx,
- kv_n_stride,
- kv_h_stride,
- kv_b_stride,
+ v_n_stride,
+ v_h_stride,
+ v_b_stride,
kv_idx_base,
chunk_end);
commit_group();
@@ -675,19 +893,34 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
wait_group<0>();
__syncthreads();
- merge_block_res_v2(
+ merge_block_res_v2(
o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid);
if (num_chunks_this_seq <= 1) {
- normalize_d(o_frag, d_frag);
+ normalize_d(o_frag, d_frag);
}
-
+#ifdef DEBUG_PERCISION_DEC
+ __syncthreads();
+ if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.z == 0 &&
+ blockIdx.x == 0) {
+ for (uint32_t i = 0; i < num_frags_x; ++i) {
+ for (uint32_t j = 0; j < num_frags_y_v; ++j) {
+ printf("o_frag[%d][%d]:\n", i, j);
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.4f ", s_frag[i][j][k]);
+ }
+ printf("\n");
+ }
+ }
+ }
+ __syncthreads();
+#endif
// write o
// [num_frags_x, 16, num_frags_y, 16]
if (num_chunks_this_seq <= 1) {
write_o_reg_gmem_multi_warps_shift_smooth_quant(
o_frag,
&qo_smem,
@@ -701,11 +934,11 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
in_scale,
q_len,
q_n_stride,
- HEAD_DIM);
+ HEAD_DIM_V);
} else {
write_o_reg_gmem_multi_warps_shift_smooth_quant(
o_frag,
&qo_smem,
@@ -719,7 +952,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
in_scale,
q_len,
q_n_stride * num_chunks,
- HEAD_DIM);
+ HEAD_DIM_V);
}
if (num_chunks_this_seq > 1) {
@@ -757,7 +990,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
template ;
if (smem_size >= 48 * 1024) {
@@ -853,11 +1090,13 @@ void MultiQueryAppendAttention(
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
- HEAD_DIM,
+ HEAD_DIM_QK,
+ HEAD_DIM_V,
BLOCK_SIZE,
num_frags_x,
num_frags_z,
- num_frags_y,
+ num_frags_y_qk,
+ num_frags_y_v,
OUT_NV_TYPE,
ENABLE_PREFILL>;
if (smem_size >= 48 * 1024) {
@@ -885,7 +1124,7 @@ void MultiQueryAppendAttention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -899,9 +1138,10 @@ void MultiQueryAppendAttention(
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (ENABLE_PREFILL) {
- tmp_workspace = allocator->Allocate(
- phi::SizeOf(qkv.dtype()) *
- static_cast(token_num * num_chunks * num_heads * HEAD_DIM));
+ tmp_workspace =
+ allocator->Allocate(phi::SizeOf(qkv.dtype()) *
+ static_cast(token_num * num_chunks *
+ num_heads * HEAD_DIM_V));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast(token_num * num_chunks * num_heads));
@@ -912,7 +1152,7 @@ void MultiQueryAppendAttention(
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
static_cast(speculate_max_draft_token_num * bsz *
- num_chunks * num_heads * HEAD_DIM));
+ num_chunks * num_heads * HEAD_DIM_V));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast(speculate_max_draft_token_num * bsz *
@@ -942,7 +1182,7 @@ void MultiQueryAppendAttention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -955,14 +1195,14 @@ void MultiQueryAppendAttention(
// merge
constexpr int vec_size = num_elems_per_128b();
if (is_decoder) {
- constexpr int blockx = HEAD_DIM / vec_size;
+ constexpr int blockx = HEAD_DIM_V / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(bsz, num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_decoder_kernel
<<>>(
@@ -987,9 +1227,9 @@ void MultiQueryAppendAttention(
num_chunks,
num_heads,
chunk_size,
- HEAD_DIM);
+ HEAD_DIM_V);
} else {
- constexpr int blockx = HEAD_DIM / vec_size;
+ constexpr int blockx = HEAD_DIM_V / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads); // 128k is too large
@@ -997,7 +1237,7 @@ void MultiQueryAppendAttention(
merge_multi_chunks_v2_kernel
<<>>(
@@ -1022,7 +1262,7 @@ void MultiQueryAppendAttention(
num_chunks,
num_heads,
chunk_size,
- HEAD_DIM,
+ HEAD_DIM_V,
token_num,
speculate_max_draft_token_num);
}
@@ -1030,8 +1270,9 @@ void MultiQueryAppendAttention(
} else {
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV;
constexpr uint32_t smem_size =
- (num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * HEAD_DIM *
- sizeof(T);
+ ((num_frags_x + NUM_WARP_KV * num_frags_z) * HEAD_DIM_QK +
+ NUM_WARP_KV * num_frags_z * HEAD_DIM_V) *
+ 16 * sizeof(T);
auto split_kv_kernel =
multi_query_append_attention_warp1_4_kernel;
if (smem_size >= 48 * 1024) {
@@ -1074,11 +1317,13 @@ void MultiQueryAppendAttention(
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
- HEAD_DIM,
+ HEAD_DIM_QK,
+ HEAD_DIM_V,
BLOCK_SIZE,
num_frags_x,
num_frags_z,
- num_frags_y,
+ num_frags_y_qk,
+ num_frags_y_v,
OUT_NV_TYPE,
ENABLE_PREFILL>;
if (smem_size >= 48 * 1024) {
@@ -1106,7 +1351,7 @@ void MultiQueryAppendAttention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -1121,7 +1366,7 @@ void MultiQueryAppendAttention(
if (is_decoder) {
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
- static_cast(bsz * num_chunks * num_heads * HEAD_DIM));
+ static_cast(bsz * num_chunks * num_heads * HEAD_DIM_V));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast(bsz * num_chunks * num_heads));
@@ -1133,7 +1378,7 @@ void MultiQueryAppendAttention(
tmp_workspace =
allocator->Allocate(phi::SizeOf(qkv.dtype()) *
static_cast(token_num * num_chunks *
- num_heads * HEAD_DIM));
+ num_heads * HEAD_DIM_V));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast(token_num * num_chunks * num_heads));
@@ -1144,7 +1389,7 @@ void MultiQueryAppendAttention(
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
static_cast(speculate_max_draft_token_num * bsz *
- num_chunks * num_heads * HEAD_DIM));
+ num_chunks * num_heads * HEAD_DIM_V));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast(speculate_max_draft_token_num * bsz *
@@ -1174,7 +1419,7 @@ void MultiQueryAppendAttention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -1188,14 +1433,14 @@ void MultiQueryAppendAttention(
// merge
constexpr int vec_size = num_elems_per_128b();
if (is_decoder) {
- constexpr int blockx = HEAD_DIM / vec_size;
+ constexpr int blockx = HEAD_DIM_V / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(bsz, num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_decoder_kernel
<<>>(
@@ -1220,17 +1465,16 @@ void MultiQueryAppendAttention(
num_chunks,
num_heads,
chunk_size,
- HEAD_DIM);
+ HEAD_DIM_V);
} else {
- constexpr int blockx = HEAD_DIM / vec_size;
+ constexpr int blockx = HEAD_DIM_V / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
- dim3 grids_merge(min(sm_count * 4, token_num),
- num_heads);
+ dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel
<<>>(
@@ -1255,7 +1499,7 @@ void MultiQueryAppendAttention(
num_chunks,
num_heads,
chunk_size,
- HEAD_DIM,
+ HEAD_DIM_V,
token_num,
speculate_max_draft_token_num);
}
@@ -1265,37 +1509,39 @@ void MultiQueryAppendAttention(
template
void CascadeAppendAttentionC16Kernel(
- const AppendAttnMetaData& meta_data,
- const paddle::Tensor& qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
- const paddle::Tensor&
- cache_k, // [max_block_num, num_heads, block_size, head_dim]
- const paddle::Tensor&
- cache_v, // [max_block_num, num_heads, head_dim, block_size]
- const paddle::optional& attn_mask,
- const paddle::optional&
- cache_k_scale, // [num_kv_heads, head_dim]
- const paddle::optional&
- cache_v_scale, // [num_kv_heads, head_dim]
- const paddle::optional&
- cache_k_zp, // [num_kv_heads, head_dim]
- const paddle::optional&
- cache_v_zp, // [num_kv_heads, head_dim]
- const paddle::optional&
- shift_bias, // [num_kv_heads, head_dim]
- const paddle::optional&
- smooth_weight, // [num_kv_heads, head_dim]
- const paddle::Tensor& seq_lens_q,
- const paddle::Tensor& seq_lens_kv,
- const paddle::Tensor& seq_lens_encoder,
- const paddle::Tensor& padding_offsets,
- const paddle::Tensor& cum_offsets,
- const paddle::Tensor& block_table,
- const paddle::Tensor& batch_ids,
- const paddle::Tensor& tile_ids_per_batch,
+ const AppendAttnMetaData &meta_data,
+ const paddle::Tensor
+ &qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
+ const paddle::Tensor
+ &cache_k, // [max_block_num, num_heads, block_size, head_dim]
+ const paddle::Tensor
+ &cache_v, // [max_block_num, num_heads, head_dim, block_size]
+ const paddle::optional<paddle::Tensor> &attn_mask,
+ const paddle::optional<paddle::Tensor>
+ &cache_k_scale, // [num_kv_heads, head_dim]
+ const paddle::optional<paddle::Tensor>
+ &cache_v_scale, // [num_kv_heads, head_dim]
+ const paddle::optional<paddle::Tensor>
+ &cache_k_zp, // [num_kv_heads, head_dim]
+ const paddle::optional<paddle::Tensor>
+ &cache_v_zp, // [num_kv_heads, head_dim]
+ const paddle::optional<paddle::Tensor>
+ &shift_bias, // [num_kv_heads, head_dim]
+ const paddle::optional<paddle::Tensor>
+ &smooth_weight, // [num_kv_heads, head_dim]
+ const paddle::Tensor &seq_lens_q,
+ const paddle::Tensor &seq_lens_kv,
+ const paddle::Tensor &seq_lens_encoder,
+ const paddle::Tensor &padding_offsets,
+ const paddle::Tensor &cum_offsets,
+ const paddle::Tensor &block_table,
+ const paddle::Tensor &batch_ids,
+ const paddle::Tensor &tile_ids_per_batch,
const int num_blocks,
const int block_shape_q,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -1303,14 +1549,15 @@ void CascadeAppendAttentionC16Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
- cudaStream_t& stream,
- paddle::Tensor* out) {
+ cudaStream_t &stream,
+ paddle::Tensor *out) {
const auto token_num = meta_data.token_nums;
const auto block_size = meta_data.block_size;
const auto bsz = meta_data.batch_size;
const auto num_heads = meta_data.q_num_heads;
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
- const auto head_dim = meta_data.head_dims;
+ const auto head_dim_qk = meta_data.head_dims;
+ const auto head_dim_v = meta_data.head_dims_v;
DISPATCH_CAUSAL(
causal,
@@ -1322,46 +1569,51 @@ void CascadeAppendAttentionC16Kernel(
group_size,
GROUP_SIZE,
{DISPATCH_HEAD_DIM(
- head_dim,
- HEAD_DIM,
- {DISPATCH_BLOCK_SIZE(
- block_size,
- BLOCK_SIZE,
- {DISPATCH_BLOCKSHAPE_Q(
- block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, {
- MultiQueryAppendAttention(
- meta_data,
- qkv,
- cache_k,
- cache_v,
- attn_mask,
- shift_bias,
- smooth_weight,
- seq_lens_q,
- seq_lens_kv,
- seq_lens_encoder,
- padding_offsets,
- cum_offsets,
- block_table,
- batch_ids,
- tile_ids_per_batch,
- num_blocks,
- max_seq_len,
- max_dec_len,
- quant_max_bound,
- quant_min_bound,
- in_scale,
- speculate_max_draft_token_num,
- is_decoder,
- stream,
- out);
- })})})})})})
+ head_dim_qk,
+ HEAD_DIM_QK,
+ {DISPATCH_HEAD_DIM(
+ head_dim_v,
+ HEAD_DIM_V,
+ {DISPATCH_BLOCK_SIZE(
+ block_size,
+ BLOCK_SIZE,
+ {DISPATCH_BLOCKSHAPE_Q(
+ block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, {
+ MultiQueryAppendAttention(
+ meta_data,
+ qkv,
+ cache_k,
+ cache_v,
+ attn_mask,
+ shift_bias,
+ smooth_weight,
+ seq_lens_q,
+ seq_lens_kv,
+ seq_lens_encoder,
+ padding_offsets,
+ cum_offsets,
+ block_table,
+ batch_ids,
+ tile_ids_per_batch,
+ num_blocks,
+ max_seq_len,
+ max_dec_len,
+ softmax_scale,
+ quant_max_bound,
+ quant_min_bound,
+ in_scale,
+ speculate_max_draft_token_num,
+ is_decoder,
+ stream,
+ out);
+ })})})})})})})
}
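
The hunks above extend `CascadeAppendAttentionC16Kernel` so that the query/key head dim (`HEAD_DIM_QK`) and the value head dim (`HEAD_DIM_V`) are dispatched separately, and so that a caller-supplied `softmax_scale` is threaded into `MultiQueryAppendAttention` instead of being hard-coded as `1/sqrt(HEAD_DIM)` inside the kernels. As a minimal sketch (not part of the patch; the helper name below is illustrative), a caller could derive the conventional scale like this:

```cpp
#include <cmath>

// Illustrative helper: the conventional scaled-dot-product value a caller can
// pass as `softmax_scale` now that the kernels no longer compute it themselves.
// Callers that need a non-standard factor (e.g. when Q/K and V head dims
// differ) can pass whatever value fits their model.
inline float default_softmax_scale(int head_dim_qk) {
  return 1.0f / std::sqrt(static_cast<float>(head_dim_qk));
}
```
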
diff --git a/csrc/gpu/append_attn/append_attention_c4_impl.cuh b/csrc/gpu/append_attn/append_attention_c4_impl.cuh
index 7d49de3966e0..fac1baf6f4c2 100644
--- a/csrc/gpu/append_attn/append_attention_c4_impl.cuh
+++ b/csrc/gpu/append_attn/append_attention_c4_impl.cuh
@@ -51,7 +51,7 @@ __global__ void multi_query_append_attention_c4_kernel(
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
- const float scale,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -189,7 +189,7 @@ __global__ void multi_query_append_attention_c4_kernel(
__syncthreads();
q_smem_inplace_multiply_sm_scale(&qo_smem,
- scale);
+ softmax_scale);
T cache_k_scale_frag[num_frags_y][4];
T cache_k_zp_frag[num_frags_y][4];
@@ -509,7 +509,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
- const float scale,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -649,7 +649,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
__syncthreads();
q_smem_inplace_multiply_sm_scale_multi_warps(
- &qo_smem, scale);
+ &qo_smem, softmax_scale);
T cache_k_scale_frag[num_frags_y][4];
T cache_k_zp_frag[num_frags_y][4];
@@ -970,6 +970,7 @@ void MultiQueryAppendC4Attention(
const int num_blocks_x_cpu,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -994,8 +995,6 @@ void MultiQueryAppendC4Attention(
auto *allocator = paddle::GetAllocator(qkv.place());
- const float scale = 1.f / sqrt(HEAD_DIM);
-
if constexpr (NUM_WARP_Q == 4) {
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16;
constexpr uint32_t smem_size =
@@ -1091,7 +1090,7 @@ void MultiQueryAppendC4Attention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -1154,7 +1153,7 @@ void MultiQueryAppendC4Attention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -1336,7 +1335,7 @@ void MultiQueryAppendC4Attention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -1412,7 +1411,7 @@ void MultiQueryAppendC4Attention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -1533,6 +1532,7 @@ void CascadeAppendAttentionC4Kernel(
const int block_shape_q,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -1597,6 +1597,7 @@ void CascadeAppendAttentionC4Kernel(
num_blocks,
max_seq_len,
max_dec_len,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
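
As in the C16 path, the C4 (int4 KV cache) path drops the internal `const float scale = 1.f / sqrt(HEAD_DIM);` and receives `softmax_scale` as an argument, which the kernels fold into the query tile before the QK^T matmul via `q_smem_inplace_multiply_sm_scale`. A plain-C++ sketch of that pre-scaling idea (scalar analogue only, not the shared-memory implementation):

```cpp
// Scalar analogue of pre-multiplying the query tile by softmax_scale so that
// the subsequent QK^T accumulation already carries the scaling factor.
void scale_query_tile(float* q_tile, int num_elems, float softmax_scale) {
  for (int i = 0; i < num_elems; ++i) {
    q_tile[i] *= softmax_scale;
  }
}
```
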
diff --git a/csrc/gpu/append_attn/append_attention_c8_impl.cuh b/csrc/gpu/append_attn/append_attention_c8_impl.cuh
index e0ede51a9c81..df2357bb192b 100644
--- a/csrc/gpu/append_attn/append_attention_c8_impl.cuh
+++ b/csrc/gpu/append_attn/append_attention_c8_impl.cuh
@@ -32,7 +32,7 @@ template
__global__ void multi_query_append_attention_c8_kernel(
- T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
+ T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
@@ -49,7 +49,7 @@ __global__ void multi_query_append_attention_c8_kernel(
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
- const float scale,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -172,7 +172,7 @@ __global__ void multi_query_append_attention_c8_kernel(
__syncthreads();
q_smem_inplace_multiply_sm_scale(&qo_smem,
- scale);
+ softmax_scale);
smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
@@ -206,8 +206,7 @@ __global__ void multi_query_append_attention_c8_kernel(
uint32_t k_smem_offset_w =
smem_t::get_permuted_offset(
- wid * 4 + tid / 8,
- tid % 8);
+ wid * 4 + tid / 8, tid % 8);
uint32_t v_smem_offset_w =
smem_t::get_permuted_offset(
wid * 8 + tid / 4, tid % 4); // 4 * 128 / 8 = 64
@@ -338,7 +337,6 @@ __global__ void multi_query_append_attention_c8_kernel(
chunk_end,
const_v_offset);
commit_group();
-
}
wait_group<0>();
__syncthreads();
@@ -434,7 +432,7 @@ template
__global__ void multi_query_append_attention_c8_warp1_4_kernel(
- T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
+ T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
@@ -451,7 +449,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
- const float scale,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -575,7 +573,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
__syncthreads();
q_smem_inplace_multiply_sm_scale_multi_warps(
- &qo_smem, scale);
+ &qo_smem, softmax_scale);
smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
@@ -610,12 +608,10 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
uint32_t k_smem_offset_w =
smem_t::get_permuted_offset(
- wid * 4 + tid / 8,
- tid %
- 8);
+ wid * 4 + tid / 8, tid % 8);
uint32_t v_smem_offset_w =
smem_t::get_permuted_offset(
- wid * 8 + tid / 4, tid % 4);
+ wid * 8 + tid / 4, tid % 4);
uint32_t kv_idx_base = chunk_start;
const uint32_t const_k_offset = kv_head_idx * kv_h_stride +
@@ -805,7 +801,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
if (qo_idx - q_start_seq_id < q_len) {
-
uint32_t offset;
if (ENABLE_PREFILL) {
offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
@@ -857,6 +852,7 @@ void MultiQueryAppendC8Attention(
const int num_blocks_x_cpu,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -881,8 +877,6 @@ void MultiQueryAppendC8Attention(
auto *allocator = paddle::GetAllocator(qkv.place());
- const float scale = 1.f / sqrt(HEAD_DIM);
-
if constexpr (NUM_WARP_Q == 4) {
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16;
constexpr uint32_t smem_size =
@@ -963,7 +957,7 @@ void MultiQueryAppendC8Attention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -1020,7 +1014,7 @@ void MultiQueryAppendC8Attention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -1069,8 +1063,7 @@ void MultiQueryAppendC8Attention(
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
- dim3 grids_merge(min(sm_count * 4, token_num),
- num_heads);
+ dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel
void CascadeAppendAttentionC8Kernel(
- const AppendAttnMetaData& meta_data,
- const paddle::Tensor& qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
- const paddle::Tensor&
- cache_k, // [max_block_num, num_heads, block_size, head_dim]
- const paddle::Tensor&
- cache_v, // [max_block_num, num_heads, head_dim, block_size]
- const paddle::optional<paddle::Tensor>& attn_mask,
- const paddle::optional<paddle::Tensor>&
- cache_k_scale, // [num_kv_heads, head_dim]
- const paddle::optional<paddle::Tensor>&
- cache_v_scale, // [num_kv_heads, head_dim]
- const paddle::optional<paddle::Tensor>&
- cache_k_zp, // [num_kv_heads, head_dim]
- const paddle::optional<paddle::Tensor>&
- cache_v_zp, // [num_kv_heads, head_dim]
- const paddle::optional<paddle::Tensor>&
- shift_bias, // [num_kv_heads, head_dim]
- const paddle::optional<paddle::Tensor>&
- smooth_weight, // [num_kv_heads, head_dim]
- const paddle::Tensor& seq_lens_q,
- const paddle::Tensor& seq_lens_kv,
- const paddle::Tensor& seq_lens_encoder,
- const paddle::Tensor& padding_offsets,
- const paddle::Tensor& cum_offsets,
- const paddle::Tensor& block_table,
- const paddle::Tensor& batch_ids,
- const paddle::Tensor& tile_ids_per_batch,
+ const AppendAttnMetaData &meta_data,
+ const paddle::Tensor
+ &qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
+ const paddle::Tensor
+ &cache_k, // [max_block_num, num_heads, block_size, head_dim]
+ const paddle::Tensor
+ &cache_v, // [max_block_num, num_heads, head_dim, block_size]
+ const paddle::optional<paddle::Tensor> &attn_mask,
+ const paddle::optional<paddle::Tensor>
+ &cache_k_scale, // [num_kv_heads, head_dim]
+ const paddle::optional<paddle::Tensor>
+ &cache_v_scale, // [num_kv_heads, head_dim]
+ const paddle::optional<paddle::Tensor>
+ &cache_k_zp, // [num_kv_heads, head_dim]
+ const paddle::optional<paddle::Tensor>
+ &cache_v_zp, // [num_kv_heads, head_dim]
+ const paddle::optional<paddle::Tensor>
+ &shift_bias, // [num_kv_heads, head_dim]
+ const paddle::optional<paddle::Tensor>
+ &smooth_weight, // [num_kv_heads, head_dim]
+ const paddle::Tensor &seq_lens_q,
+ const paddle::Tensor &seq_lens_kv,
+ const paddle::Tensor &seq_lens_encoder,
+ const paddle::Tensor &padding_offsets,
+ const paddle::Tensor &cum_offsets,
+ const paddle::Tensor &block_table,
+ const paddle::Tensor &batch_ids,
+ const paddle::Tensor &tile_ids_per_batch,
const int num_blocks,
const int block_shape_q,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -1379,8 +1373,8 @@ void CascadeAppendAttentionC8Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
- cudaStream_t& stream,
- paddle::Tensor* out) {
+ cudaStream_t &stream,
+ paddle::Tensor *out) {
const auto token_num = meta_data.token_nums;
const auto block_size = meta_data.block_size;
const auto bsz = meta_data.batch_size;
@@ -1434,6 +1428,7 @@ void CascadeAppendAttentionC8Kernel(
num_blocks,
max_seq_len,
max_dec_len,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
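
The C8 (int8 KV cache) path gets the same treatment: `softmax_scale` replaces the locally computed `scale` and is forwarded from `CascadeAppendAttentionC8Kernel` down to the `multi_query_append_attention_c8_*` kernels. The merge step's launch geometry shown above (grid capped at `sm_count * 4` blocks along the token axis) can be summarized with this host-side sketch (illustrative names, assuming CUDA's `dim3`):

```cpp
#include <algorithm>
#include <cuda_runtime.h>

struct MergeLaunch {
  dim3 grid;
  dim3 block;
};

// Mirror of the launch math used for merge_multi_chunks_v2_kernel above:
// blockx lanes cover one head_dim worth of 128-bit vectors, blocky pads the
// block up to roughly 128 threads, and the grid covers (token, head) pairs
// but is capped at four waves of SMs in the token dimension.
inline MergeLaunch merge_launch(int sm_count, int token_num, int num_heads,
                                int head_dim, int vec_size) {
  const int blockx = head_dim / vec_size;
  const int blocky = (128 + blockx - 1) / blockx;
  return {dim3(std::min(sm_count * 4, token_num), num_heads),
          dim3(blockx, blocky)};
}
```
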
diff --git a/csrc/gpu/append_attn/append_attention_kernel.h b/csrc/gpu/append_attn/append_attention_kernel.h
index b0fabcf893d3..b34c2a044733 100644
--- a/csrc/gpu/append_attn/append_attention_kernel.h
+++ b/csrc/gpu/append_attn/append_attention_kernel.h
@@ -49,6 +49,7 @@ void CascadeAppendAttentionC16Kernel(
const int block_shape_q,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -92,6 +93,7 @@ void CascadeAppendAttentionC8Kernel(
const int block_shape_q,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -135,6 +137,7 @@ void CascadeAppendAttentionC4Kernel(
const int block_shape_q,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -179,6 +182,7 @@ void CascadeAppendAttentionKernel(
const int block_shape_q,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -212,6 +216,7 @@ void CascadeAppendAttentionKernel(
block_shape_q,
max_seq_len,
max_dec_len,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -245,6 +250,7 @@ void CascadeAppendAttentionKernel(
block_shape_q,
max_seq_len,
max_dec_len,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -278,6 +284,7 @@ void CascadeAppendAttentionKernel(
block_shape_q,
max_seq_len,
max_dec_len,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
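
`append_attention_kernel.h` adds `const float softmax_scale` to all three `CascadeAppendAttentionC{16,8,4}Kernel` declarations and forwards it from the `CascadeAppendAttentionKernel` dispatcher. A heavily simplified view of that forwarding pattern (real signatures carry many more arguments; the helper below is illustrative, not the library API):

```cpp
#include <stdexcept>
#include <string>

// Simplified sketch: the dispatcher picks the per-quantization kernel based on
// the cache quant type and passes softmax_scale straight through unchanged.
template <typename C16, typename C8, typename C4>
void dispatch_by_cache_quant(const std::string& cache_quant_type,
                             float softmax_scale,
                             C16&& c16, C8&& c8, C4&& c4) {
  if (cache_quant_type == "none") {
    c16(softmax_scale);   // -> CascadeAppendAttentionC16Kernel(..., softmax_scale, ...)
  } else if (cache_quant_type == "cache_int8") {
    c8(softmax_scale);    // -> CascadeAppendAttentionC8Kernel(..., softmax_scale, ...)
  } else if (cache_quant_type == "cache_int4_zp") {
    c4(softmax_scale);    // -> CascadeAppendAttentionC4Kernel(..., softmax_scale, ...)
  } else {
    throw std::invalid_argument("unsupported cache_quant_type");
  }
}
```
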
diff --git a/csrc/gpu/append_attn/decoder_write_cache_with_rope_impl.cuh b/csrc/gpu/append_attn/decoder_write_cache_with_rope_impl.cuh
index 1a8e73759022..5fbb53f05801 100644
--- a/csrc/gpu/append_attn/decoder_write_cache_with_rope_impl.cuh
+++ b/csrc/gpu/append_attn/decoder_write_cache_with_rope_impl.cuh
@@ -122,6 +122,91 @@ __global__ void append_decode_cache_T_rope_kernel(
}
}
+template <typename T, int VecSize = 1>
+__global__ void append_decode_cache_T_kernel(
+ const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
+ // head_size]
+ T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
+ // head_size_qk]
+ T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
+ // head_size_v]
+ const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
+ const int* __restrict__ padding_offsets, // [num_tokens]
+ const int* __restrict__ cum_offsets,
+ const int* __restrict__ seq_lens, // [bsz]
+ const int* __restrict__ seq_lens_encoder, // [bsz]
+ const int max_seq_len,
+ const int max_blocks_per_seq,
+ const int num_heads,
+ const int head_size_qk,
+ const int head_size_v,
+ const int block_size,
+ const uint32_t elem_cnt,
+ const int kv_num_heads) {
+ using LoadT = AlignedVector<T, VecSize>;
+ using LoadBiasT = AlignedVector<T, VecSize>;
+ using LoadKVT = AlignedVector<T, VecSize>;
+ constexpr int HalfVecSize = VecSize / 2;
+ using LoadEmbT = AlignedVector<float, HalfVecSize>;
+ LoadT src_vec;
+ LoadBiasT out_vec;
+ LoadKVT cache_vec;
+
+ int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+ // const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
+ const uint32_t hidden_size_q = num_heads * head_size_qk;
+ const uint32_t hidden_size_k = kv_num_heads * head_size_qk;
+ const uint32_t hidden_size_v = kv_num_heads * head_size_v;
+ const int64_t hidden_size = hidden_size_q + hidden_size_k + hidden_size_v;
+ const uint32_t offset = kv_num_heads * (head_size_qk + head_size_v);
+ // const int64_t offset = 2 * hidden_size;
+ // const int half_head_size = head_size / 2;
+ for (int32_t linear_index = global_thread_idx * VecSize,
+ step = gridDim.x * blockDim.x * VecSize;
+ linear_index < elem_cnt;
+ linear_index += step) {
+ const int ori_bi = linear_index / offset;
+ const int bias = linear_index % offset;
+ const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
+ if (seq_lens_encoder[ori_bi] > 0) return;
+ const int write_seq_id = seq_lens[ori_bi];
+
+ if (write_seq_id == 0) continue;
+
+ const int* block_table_now = nullptr;
+
+ block_table_now = block_tables + ori_bi * max_blocks_per_seq;
+ const int block_idx = block_table_now[write_seq_id / block_size];
+ const int block_offset = write_seq_id % block_size;
+
+ if (bias < hidden_size_k) {
+ const uint32_t qkv_bias = bias;
+ const uint32_t hi = qkv_bias / head_size_qk;
+ const uint32_t h_bias = qkv_bias % head_size_qk;
+ const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * head_size_qk +
+ hi * block_size * head_size_qk +
+ block_offset * head_size_qk + h_bias;
+ const uint32_t ori_idx =
+ start_token_idx * hidden_size +
+ hidden_size_q + qkv_bias;
+ Load<T, VecSize>(&qkv[ori_idx], &src_vec);
+ Store<T, VecSize>(src_vec, &key_cache[tgt_idx]);
+ } else {
+ const uint32_t qkv_bias = bias - hidden_size_k;
+ const uint32_t hi = qkv_bias / head_size_v;
+ const uint32_t h_bias = qkv_bias % head_size_v;
+ const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * head_size_v +
+ hi * block_size * head_size_v +
+ block_offset * head_size_v + h_bias;
+ const uint32_t ori_idx =
+ start_token_idx * hidden_size +
+ hidden_size_q + hidden_size_k + qkv_bias;
+ Load<T, VecSize>(&qkv[ori_idx], &src_vec);
+ Store<T, VecSize>(src_vec, &value_cache[tgt_idx]);
+ }
+ }
+}
+
template
__global__ void append_decode_cache_T_rope_kernel(
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
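
The new `append_decode_cache_T_kernel` writes the current decode token's K and V straight into the paged cache without applying RoPE, and it handles a K head dim (`head_size_qk`) that differs from the V head dim (`head_size_v`). The index arithmetic it uses can be checked with this scalar C++ sketch (names mirror the kernel; this is an illustration, not the CUDA code):

```cpp
#include <cstdint>

// Given a flat offset `bias` inside one token's concatenated K|V slice,
// compute the destination offset inside the paged key or value cache,
// following the same arithmetic as the kernel above.
struct KvTarget {
  bool is_value;       // false -> key_cache, true -> value_cache
  uint32_t cache_idx;  // element offset inside the chosen cache
};

inline KvTarget kv_cache_target(uint32_t bias, uint32_t hidden_size_k,
                                uint32_t head_size_qk, uint32_t head_size_v,
                                uint32_t kv_num_heads, uint32_t block_idx,
                                uint32_t block_offset, uint32_t block_size) {
  if (bias < hidden_size_k) {  // K region: kv_num_heads * head_size_qk lanes
    const uint32_t hi = bias / head_size_qk;
    const uint32_t h_bias = bias % head_size_qk;
    return {false, block_idx * kv_num_heads * block_size * head_size_qk +
                       hi * block_size * head_size_qk +
                       block_offset * head_size_qk + h_bias};
  }
  const uint32_t vb = bias - hidden_size_k;  // V region: head_size_v lanes
  const uint32_t hi = vb / head_size_v;
  const uint32_t h_bias = vb % head_size_v;
  return {true, block_idx * kv_num_heads * block_size * head_size_v +
                    hi * block_size * head_size_v +
                    block_offset * head_size_v + h_bias};
}
```
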
diff --git a/csrc/gpu/append_attn/decoder_write_cache_with_rope_kernel.cu b/csrc/gpu/append_attn/decoder_write_cache_with_rope_kernel.cu
index ee0cd57e307c..08483feb2a5c 100644
--- a/csrc/gpu/append_attn/decoder_write_cache_with_rope_kernel.cu
+++ b/csrc/gpu/append_attn/decoder_write_cache_with_rope_kernel.cu
@@ -15,6 +15,54 @@
#include "decoder_write_cache_with_rope_kernel.h"
#include "utils.cuh"
+
+template <typename T>
+void DecoderWriteCacheKV(const AppendAttnMetaData& meta_data,
+ const paddle::Tensor& qkv,
+ const paddle::Tensor& seq_lens,
+ const paddle::Tensor& seq_lens_encoder,
+ const paddle::Tensor& padding_offsets,
+ const paddle::Tensor& cum_offsets,
+ const paddle::Tensor& block_tables,
+ const int max_seq_len,
+ cudaStream_t& stream,
+ paddle::Tensor* key_cache_out,
+ paddle::Tensor* value_cache_out) {
+ auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
+ auto bsz = meta_data.batch_size;
+ auto block_size = meta_data.block_size;
+ auto head_dim_qk = meta_data.head_dims;
+ auto head_dim_v = meta_data.head_dims_v;
+ auto num_heads = meta_data.q_num_heads;
+ auto kv_num_heads = meta_data.kv_num_heads;
+ const uint32_t elem_nums = bsz * kv_num_heads * (head_dim_qk + head_dim_v);
+
+ constexpr int PackSize = 16 / sizeof(T);
+ const int pack_num = elem_nums / PackSize;
+ const int blocksize = 128;
+ int grid_size = 1;
+ GetNumBlocks<128>(pack_num, &grid_size);
+
+ append_decode_cache_T_kernel
+ <<<grid_size, blocksize, 0, stream>>>(
+ reinterpret_cast(const_cast(qkv.data())),
+ reinterpret_cast(key_cache_out->data()),
+ reinterpret_cast(value_cache_out->data()),
+ block_tables.data<int>(),
+ padding_offsets.data<int>(),
+ cum_offsets.data<int>(),
+ seq_lens.data<int>(),
+ seq_lens_encoder.data<int>(),
+ max_seq_len,
+ max_blocks_per_seq,
+ num_heads,
+ head_dim_qk,
+ head_dim_v,
+ block_size,
+ elem_nums,
+ kv_num_heads);
+}
+
template
void append_decode_cache_rope(const QKV_TYPE* qkv,
T* key_cache,
@@ -449,115 +497,125 @@ void DecoderWriteCacheWithRoPEKernel(
auto num_heads = meta_data.q_num_heads;
auto kv_num_heads = meta_data.kv_num_heads;
- const float* cos_emb =
- rotary_embs ? rotary_embs.get().data<float>() : nullptr;
- const float* sin_emb;
if (rotary_embs) {
- sin_emb =
+ const float* cos_emb = rotary_embs.get().data<float>();
+ const float* sin_emb =
use_neox_rotary_style
? rotary_embs.get().data<float>() + max_seq_len * dim_head
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
- }
- if (cache_quant_type_str == "none") {
- append_decode_cache_rope(
- reinterpret_cast(qkv_ptr),
- reinterpret_cast(key_cache_out->data()),
- reinterpret_cast(value_cache_out->data()),
- reinterpret_cast(qkv_out->data()),
- block_tables.data<int>(),
- padding_offsets.data<int>(),
- cum_offsets.data<int>(),
- seq_lens.data<int>(),
- seq_lens_encoder.data<int>(),
- cos_emb,
- sin_emb,
- qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
- qkv_biases ? reinterpret_cast(
- const_cast(qkv_biases.get().data()))
- : nullptr,
- max_seq_len,
- max_blocks_per_seq,
- num_heads,
- kv_num_heads,
- dim_head,
- block_size,
- bsz,
- stream,
- use_neox_rotary_style);
- } else if (cache_quant_type_str == "cache_int8") {
- append_decode_cache_int8_rope(
- reinterpret_cast(qkv_ptr),
- key_cache_out->data(),
- value_cache_out->data(),
- reinterpret_cast(qkv_out->data()),
- block_tables.data<int>(),
- padding_offsets.data<int>(),
- cum_offsets.data<int>(),
- seq_lens.data<int>(),
- seq_lens_encoder.data<int>(),
- cos_emb,
- sin_emb,
- qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
- qkv_biases ? reinterpret_cast(
- const_cast(qkv_biases.get().data()))
- : nullptr,
- cache_k_scale ? reinterpret_cast(
- const_cast(cache_k_scale.get().data()))
- : nullptr,
- cache_v_scale ? reinterpret_cast(
- const_cast(cache_v_scale.get().data()))
- : nullptr,
- max_seq_len,
- max_blocks_per_seq,
- num_heads,
- kv_num_heads,
- dim_head,
- block_size,
- bsz,
- stream,
- use_neox_rotary_style);
- } else if (cache_quant_type_str == "cache_int4_zp") {
- append_decode_cache_int4_rope(
- reinterpret_cast(qkv_ptr),
- key_cache_out->data(),
- value_cache_out->data(),
- reinterpret_cast(const_cast(qkv_out->data())),
- block_tables.data(),
- padding_offsets.data(),
- cum_offsets.data