-
* **2024.08.08 📚 "PaddleNLP 3.0: PaddlePaddle's Industrial-Grade LLM Development Toolkit" Officially Released**: training, compression, and inference connected end to end, with full coverage of mainstream models; automatic parallelism for large models, with the full training-and-inference workflow for hundred-billion-parameter models available out of the box; industrial-grade, high-performance fine-tuning and alignment solutions, leading compression and inference performance, and multi-hardware adaptation; covering application scenarios such as industrial intelligent assistants, content creation, knowledge Q&A, and key information extraction. Live stream: Thursday, August 22, 19:00. Registration link: https://www.wjx.top/vm/Y2f7FFY.aspx?udsid=143844
* **2024.06.27 [PaddleNLP v3.0 Beta](https://github.com/PaddlePaddle/PaddleNLP/releases/tag/v3.0.0-beta0)**: embrace LLMs with a fully upgraded experience. A unified LLM toolkit brings end-to-end support for domestic compute chips; full support for industrial-grade LLM workflows such as PaddlePaddle 4D parallel configuration, efficient fine-tuning strategies, efficient alignment algorithms, and high-performance inference; the self-developed, fast-converging RsLoRA+ algorithm, the auto-scaling Unified Checkpoint storage mechanism, and generalized FastFFN and FusedQKV support accelerate LLM training and inference; mainstream models are continuously supported and updated, with efficient solutions provided.
@@ -79,35 +83,36 @@
* Supported model parameters cover the LLaMA, Baichuan, Bloom, ChatGLM, Gemma, Mistral, OPT, Qwen, and other series listed below. The detailed 👉 [LLM] model parameter support list is as follows (a minimal loading example appears after the table):
-| Model Series | Model Name |
-|:-------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| [LLaMA](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | facebook/llama-7b, facebook/llama-13b, facebook/llama-30b, facebook/llama-65b |
-| [Llama2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-2-7b, meta-llama/Llama-2-7b-chat, meta-llama/Llama-2-13b, meta-llama/Llama-2-13b-chat, meta-llama/Llama-2-70b, meta-llama/Llama-2-70b-chat |
-| [Llama3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Meta-Llama-3-8B, meta-llama/Meta-Llama-3-8B-Instruct, meta-llama/Meta-Llama-3-70B, meta-llama/Meta-Llama-3-70B-Instruct |
-| [Llama3.1](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Meta-Llama-3.1-8B, meta-llama/Meta-Llama-3.1-8B-Instruct, meta-llama/Meta-Llama-3.1-70B, meta-llama/Meta-Llama-3.1-70B-Instruct, meta-llama/Meta-Llama-3.1-405B, meta-llama/Meta-Llama-3.1-405B-Instruct, meta-llama/Llama-Guard-3-8B |
-| [Llama3.2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-3.2-1B, meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.2-3B, meta-llama/Llama-3.2-3B-Instruct, meta-llama/Llama-Guard-3-1B |
-| [Llama3.3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-3.3-70B-Instruct |
-| [Baichuan](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/baichuan) | baichuan-inc/Baichuan-7B, baichuan-inc/Baichuan-13B-Base, baichuan-inc/Baichuan-13B-Chat |
-| [Baichuan2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/baichuan) | baichuan-inc/Baichuan2-7B-Base, baichuan-inc/Baichuan2-7B-Chat, baichuan-inc/Baichuan2-13B-Base, baichuan-inc/Baichuan2-13B-Chat |
-| [Bloom](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/bloom) | bigscience/bloom-560m, bigscience/bloom-560m-bf16, bigscience/bloom-1b1, bigscience/bloom-3b, bigscience/bloom-7b1, bigscience/bloomz-560m, bigscience/bloomz-1b1, bigscience/bloomz-3b, bigscience/bloomz-7b1-mt, bigscience/bloomz-7b1-p3, bigscience/bloomz-7b1, bellegroup/belle-7b-2m |
-| [ChatGLM](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm/) | THUDM/chatglm-6b, THUDM/chatglm-6b-v1.1 |
-| [ChatGLM2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm2-6b |
-| [ChatGLM3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm3-6b |
-| [DeepSeekV2](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V2, deepseek-ai/DeepSeek-V2-Chat, deepseek-ai/DeepSeek-V2-Lite, deepseek-ai/DeepSeek-V2-Lite-Chat, deepseek-ai/DeepSeek-Coder-V2-Base, deepseek-ai/DeepSeek-Coder-V2-Instruct, deepseek-ai/DeepSeek-Coder-V2-Lite-Base, deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct |
-| [DeepSeekV3](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V3, deepseek-ai/DeepSeek-V3-Base |
-| [DeepSeek-R1](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-R1, deepseek-ai/DeepSeek-R1-Zero, deepseek-ai/DeepSeek-R1-Distill-Llama-70B, deepseek-ai/DeepSeek-R1-Distill-Llama-8B, deepseek-ai/DeepSeek-R1-Distill-Qwen-14B, deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, deepseek-ai/DeepSeek-R1-Distill-Qwen-32B, deepseek-ai/DeepSeek-R1-Distill-Qwen-7B |
-| [Gemma](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/gemma) | google/gemma-7b, google/gemma-7b-it, google/gemma-2b, google/gemma-2b-it |
-| [Mistral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mistral) | mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-7B-v0.1 |
-| [Mixtral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mixtral) | mistralai/Mixtral-8x7B-Instruct-v0.1 |
-| [OPT](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/opt) | facebook/opt-125m, facebook/opt-350m, facebook/opt-1.3b, facebook/opt-2.7b, facebook/opt-6.7b, facebook/opt-13b, facebook/opt-30b, facebook/opt-66b, facebook/opt-iml-1.3b, opt-iml-max-1.3b |
-| [Qwen](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | qwen/qwen-7b, qwen/qwen-7b-chat, qwen/qwen-14b, qwen/qwen-14b-chat, qwen/qwen-72b, qwen/qwen-72b-chat, |
-| [Qwen1.5](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen1.5-0.5B, Qwen/Qwen1.5-0.5B-Chat, Qwen/Qwen1.5-1.8B, Qwen/Qwen1.5-1.8B-Chat, Qwen/Qwen1.5-4B, Qwen/Qwen1.5-4B-Chat, Qwen/Qwen1.5-7B, Qwen/Qwen1.5-7B-Chat, Qwen/Qwen1.5-14B, Qwen/Qwen1.5-14B-Chat, Qwen/Qwen1.5-32B, Qwen/Qwen1.5-32B-Chat, Qwen/Qwen1.5-72B, Qwen/Qwen1.5-72B-Chat, Qwen/Qwen1.5-110B, Qwen/Qwen1.5-110B-Chat, Qwen/Qwen1.5-MoE-A2.7B, Qwen/Qwen1.5-MoE-A2.7B-Chat |
-| [Qwen2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2-0.5B, Qwen/Qwen2-0.5B-Instruct, Qwen/Qwen2-1.5B, Qwen/Qwen2-1.5B-Instruct, Qwen/Qwen2-7B, Qwen/Qwen2-7B-Instruct, Qwen/Qwen2-72B, Qwen/Qwen2-72B-Instruct, Qwen/Qwen2-57B-A14B, Qwen/Qwen2-57B-A14B-Instruct |
-| [Qwen2-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2-Math-1.5B, Qwen/Qwen2-Math-1.5B-Instruct, Qwen/Qwen2-Math-7B, Qwen/Qwen2-Math-7B-Instruct, Qwen/Qwen2-Math-72B, Qwen/Qwen2-Math-72B-Instruct, Qwen/Qwen2-Math-RM-72B |
-| [Qwen2.5](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-0.5B, Qwen/Qwen2.5-0.5B-Instruct, Qwen/Qwen2.5-1.5B, Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-3B, Qwen/Qwen2.5-3B-Instruct, Qwen/Qwen2.5-7B, Qwen/Qwen2.5-7B-Instruct, Qwen/Qwen2.5-14B, Qwen/Qwen2.5-14B-Instruct, Qwen/Qwen2.5-32B, Qwen/Qwen2.5-32B-Instruct, Qwen/Qwen2.5-72B, Qwen/Qwen2.5-72B-Instruct |
-| [Qwen2.5-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Math-1.5B, Qwen/Qwen2.5-Math-1.5B-Instruct, Qwen/Qwen2.5-Math-7B, Qwen/Qwen2.5-Math-7B-Instruct, Qwen/Qwen2.5-Math-72B, Qwen/Qwen2.5-Math-72B-Instruct, Qwen/Qwen2.5-Math-RM-72B |
-| [Qwen2.5-Coder](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Coder-1.5B, Qwen/Qwen2.5-Coder-1.5B-Instruct, Qwen/Qwen2.5-Coder-7B, Qwen/Qwen2.5-Coder-7B-Instruct |
-| [Yuan2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/yuan/) | IEITYuan/Yuan2-2B, IEITYuan/Yuan2-51B, IEITYuan/Yuan2-102B |
+| Model Series | Model Name |
+|:-------------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| [PP-UIE](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/application/information_extraction) | paddlenlp/PP-UIE-0.5B, paddlenlp/PP-UIE-1.5B, paddlenlp/PP-UIE-7B, paddlenlp/PP-UIE-14B |
+| [LLaMA](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | facebook/llama-7b, facebook/llama-13b, facebook/llama-30b, facebook/llama-65b |
+| [Llama2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-2-7b, meta-llama/Llama-2-7b-chat, meta-llama/Llama-2-13b, meta-llama/Llama-2-13b-chat, meta-llama/Llama-2-70b, meta-llama/Llama-2-70b-chat |
+| [Llama3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Meta-Llama-3-8B, meta-llama/Meta-Llama-3-8B-Instruct, meta-llama/Meta-Llama-3-70B, meta-llama/Meta-Llama-3-70B-Instruct |
+| [Llama3.1](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Meta-Llama-3.1-8B, meta-llama/Meta-Llama-3.1-8B-Instruct, meta-llama/Meta-Llama-3.1-70B, meta-llama/Meta-Llama-3.1-70B-Instruct, meta-llama/Meta-Llama-3.1-405B, meta-llama/Meta-Llama-3.1-405B-Instruct, meta-llama/Llama-Guard-3-8B |
+| [Llama3.2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-3.2-1B, meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.2-3B, meta-llama/Llama-3.2-3B-Instruct, meta-llama/Llama-Guard-3-1B |
+| [Llama3.3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-3.3-70B-Instruct |
+| [Baichuan](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/baichuan) | baichuan-inc/Baichuan-7B, baichuan-inc/Baichuan-13B-Base, baichuan-inc/Baichuan-13B-Chat |
+| [Baichuan2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/baichuan) | baichuan-inc/Baichuan2-7B-Base, baichuan-inc/Baichuan2-7B-Chat, baichuan-inc/Baichuan2-13B-Base, baichuan-inc/Baichuan2-13B-Chat |
+| [Bloom](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/bloom) | bigscience/bloom-560m, bigscience/bloom-560m-bf16, bigscience/bloom-1b1, bigscience/bloom-3b, bigscience/bloom-7b1, bigscience/bloomz-560m, bigscience/bloomz-1b1, bigscience/bloomz-3b, bigscience/bloomz-7b1-mt, bigscience/bloomz-7b1-p3, bigscience/bloomz-7b1, bellegroup/belle-7b-2m |
+| [ChatGLM](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm/) | THUDM/chatglm-6b, THUDM/chatglm-6b-v1.1 |
+| [ChatGLM2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm2-6b |
+| [ChatGLM3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm3-6b |
+| [DeepSeekV2](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V2, deepseek-ai/DeepSeek-V2-Chat, deepseek-ai/DeepSeek-V2-Lite, deepseek-ai/DeepSeek-V2-Lite-Chat, deepseek-ai/DeepSeek-Coder-V2-Base, deepseek-ai/DeepSeek-Coder-V2-Instruct, deepseek-ai/DeepSeek-Coder-V2-Lite-Base, deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct |
+| [DeepSeekV3](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V3, deepseek-ai/DeepSeek-V3-Base |
+| [DeepSeek-R1](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-R1, deepseek-ai/DeepSeek-R1-Zero, deepseek-ai/DeepSeek-R1-Distill-Llama-70B, deepseek-ai/DeepSeek-R1-Distill-Llama-8B, deepseek-ai/DeepSeek-R1-Distill-Qwen-14B, deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, deepseek-ai/DeepSeek-R1-Distill-Qwen-32B, deepseek-ai/DeepSeek-R1-Distill-Qwen-7B |
+| [Gemma](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/gemma) | google/gemma-7b, google/gemma-7b-it, google/gemma-2b, google/gemma-2b-it |
+| [Mistral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mistral) | mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-7B-v0.1 |
+| [Mixtral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mixtral) | mistralai/Mixtral-8x7B-Instruct-v0.1 |
+| [OPT](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/opt) | facebook/opt-125m, facebook/opt-350m, facebook/opt-1.3b, facebook/opt-2.7b, facebook/opt-6.7b, facebook/opt-13b, facebook/opt-30b, facebook/opt-66b, facebook/opt-iml-1.3b, opt-iml-max-1.3b |
+| [Qwen](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/)                           | qwen/qwen-7b, qwen/qwen-7b-chat, qwen/qwen-14b, qwen/qwen-14b-chat, qwen/qwen-72b, qwen/qwen-72b-chat |
+| [Qwen1.5](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen1.5-0.5B, Qwen/Qwen1.5-0.5B-Chat, Qwen/Qwen1.5-1.8B, Qwen/Qwen1.5-1.8B-Chat, Qwen/Qwen1.5-4B, Qwen/Qwen1.5-4B-Chat, Qwen/Qwen1.5-7B, Qwen/Qwen1.5-7B-Chat, Qwen/Qwen1.5-14B, Qwen/Qwen1.5-14B-Chat, Qwen/Qwen1.5-32B, Qwen/Qwen1.5-32B-Chat, Qwen/Qwen1.5-72B, Qwen/Qwen1.5-72B-Chat, Qwen/Qwen1.5-110B, Qwen/Qwen1.5-110B-Chat, Qwen/Qwen1.5-MoE-A2.7B, Qwen/Qwen1.5-MoE-A2.7B-Chat |
+| [Qwen2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2-0.5B, Qwen/Qwen2-0.5B-Instruct, Qwen/Qwen2-1.5B, Qwen/Qwen2-1.5B-Instruct, Qwen/Qwen2-7B, Qwen/Qwen2-7B-Instruct, Qwen/Qwen2-72B, Qwen/Qwen2-72B-Instruct, Qwen/Qwen2-57B-A14B, Qwen/Qwen2-57B-A14B-Instruct |
+| [Qwen2-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2-Math-1.5B, Qwen/Qwen2-Math-1.5B-Instruct, Qwen/Qwen2-Math-7B, Qwen/Qwen2-Math-7B-Instruct, Qwen/Qwen2-Math-72B, Qwen/Qwen2-Math-72B-Instruct, Qwen/Qwen2-Math-RM-72B |
+| [Qwen2.5](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-0.5B, Qwen/Qwen2.5-0.5B-Instruct, Qwen/Qwen2.5-1.5B, Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-3B, Qwen/Qwen2.5-3B-Instruct, Qwen/Qwen2.5-7B, Qwen/Qwen2.5-7B-Instruct, Qwen/Qwen2.5-14B, Qwen/Qwen2.5-14B-Instruct, Qwen/Qwen2.5-32B, Qwen/Qwen2.5-32B-Instruct, Qwen/Qwen2.5-72B, Qwen/Qwen2.5-72B-Instruct |
+| [Qwen2.5-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Math-1.5B, Qwen/Qwen2.5-Math-1.5B-Instruct, Qwen/Qwen2.5-Math-7B, Qwen/Qwen2.5-Math-7B-Instruct, Qwen/Qwen2.5-Math-72B, Qwen/Qwen2.5-Math-72B-Instruct, Qwen/Qwen2.5-Math-RM-72B |
+| [Qwen2.5-Coder](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Coder-1.5B, Qwen/Qwen2.5-Coder-1.5B-Instruct, Qwen/Qwen2.5-Coder-7B, Qwen/Qwen2.5-Coder-7B-Instruct |
+| [Yuan2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/yuan/) | IEITYuan/Yuan2-2B, IEITYuan/Yuan2-51B, IEITYuan/Yuan2-102B |
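
Every name in the table can be passed straight to PaddleNLP's `Auto*` loaders. Below is a minimal loading-and-generation sketch; it assumes a GPU build of PaddlePaddle, and the exact `generate` return format may vary slightly across PaddleNLP versions:

```python
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

# Pick any model name from the support list above.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, dtype="bfloat16")

# Encode a prompt into Paddle tensors ("pd") and generate a short reply.
inputs = tokenizer("Give a one-sentence introduction to PaddleNLP.", return_tensors="pd")
ids, _ = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(ids[0], skip_special_tokens=True))
```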
* 4D parallelism and operator optimizations are already supported for the LLaMA, Baichuan, Bloom, ChatGLM, Gemma, Mistral, OPT, and Qwen series. The [LLM] model 4D parallelism and operator support list is as follows:
@@ -166,7 +171,7 @@
### Environment Dependencies
* python >= 3.8
-* paddlepaddle >= 3.0.0b0
+* paddlepaddle >= 3.0.0rc0
If you have not installed PaddlePaddle yet, please refer to the [PaddlePaddle official website](https://www.paddlepaddle.org.cn/) for installation instructions.
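
To quickly confirm that the installed framework satisfies the requirement above, the version and device setup can be checked from Python (a minimal sketch; `paddle.utils.run_check()` also exercises a small computation on the GPU when one is available):

```python
import paddle

# Should print a version >= 3.0.0rc0 for this branch of PaddleNLP.
print(paddle.__version__)

# Runs a small sanity check of the installation (CPU or GPU).
paddle.utils.run_check()
```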
@@ -211,7 +216,7 @@ wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwe
wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.idx
cd .. # change folder to PaddleNLP/llm
# To use use_fused_rms_norm=true, first install fused_ln from slm/model_zoo/gpt-3/external_ops
-python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_pretrain.py ./config/llama/pretrain_argument.json --use_fused_rms_norm false
+python -u run_pretrain.py ./config/qwen/pretrain_argument_0p5b.json
```
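
The JSON file passed to `run_pretrain.py` is also where distributed training is configured: the 4D-parallel axes (data, tensor, pipeline, and sharding parallelism) are controlled by fields like the ones sketched below. The key names follow the shipped `pretrain_argument` configs, but the exact set and values depend on the model and the number of GPUs, so treat this as an illustrative subset rather than a complete config:

```python
# Illustrative parallelism-related fields of a pretrain/finetune argument JSON.
# Total GPUs used is roughly data_parallel * tensor_parallel * pipeline_parallel * sharding_parallel.
parallel_config = {
    "tensor_parallel_degree": 2,    # split each weight matrix across 2 GPUs
    "pipeline_parallel_degree": 2,  # split transformer layers into 2 pipeline stages
    "sharding": "stage2",           # ZeRO-style sharding of optimizer states and gradients
    "sharding_parallel_degree": 2,
}
```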
### Large Model SFT Fine-Tuning
@@ -221,7 +226,7 @@ git clone https://github.com/PaddlePaddle/PaddleNLP.git && cd PaddleNLP # 如已
mkdir -p llm/data && cd llm/data
wget https://bj.bcebos.com/paddlenlp/datasets/examples/AdvertiseGen.tar.gz && tar -zxvf AdvertiseGen.tar.gz
cd .. # change folder to PaddleNLP/llm
-python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_finetune.py ./config/llama/sft_argument.json
+python -u run_finetune.py ./config/qwen/sft_argument_0p5b.json
```
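
The fine-tuning entry point reads JSON-lines data; in the shipped examples each line provides a `src` field (instruction or input) and a `tgt` field (target response). The following is a sketch of preparing a custom dataset in that shape (the authoritative schema is described in the llm directory's dataset docs, so adjust field names if they differ in your version):

```python
import json

# One training sample per line: "src" is the prompt, "tgt" is the desired reply.
samples = [
    {"src": "Type#dress*Color#blue", "tgt": "A light blue dress that keeps you cool in summer."},
]
with open("data/train.json", "w", encoding="utf-8") as f:
    for sample in samples:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
```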
For more end-to-end large model workflow steps, please refer to the [PaddlePaddle LLM toolkit](./llm) documentation.
@@ -236,7 +241,7 @@ dataset = load_dataset("ZHUI/alpaca_demo", split="train")
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT", device="gpu")
trainer = SFTTrainer(
args=training_args,
- model="Qwen/Qwen2.5-0.5B",
+ model="Qwen/Qwen2.5-0.5B-Instruct",
train_dataset=dataset,
)
trainer.train()
diff --git a/csrc/gpu/append_attention.cu b/csrc/gpu/append_attention.cu
index d24a20e48d11..d6f3efbbf3df 100644
--- a/csrc/gpu/append_attention.cu
+++ b/csrc/gpu/append_attention.cu
@@ -56,6 +56,7 @@ std::vector
AppendAttentionKernel(
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const int max_input_length,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
@@ -97,13 +98,13 @@ std::vector AppendAttentionKernel(
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
- {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+ {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
paddle::DataType::INT8,
qkv.place());
}
else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
- {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+ {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
paddle::DataType::FLOAT8_E4M3FN,
qkv.place());
}else{
@@ -111,7 +112,7 @@ std::vector AppendAttentionKernel(
}
} else {
fmha_out = GetEmptyTensor(
- {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+ {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
D,
qkv.place());
}
@@ -203,6 +204,7 @@ std::vector AppendAttentionKernel(
encoder_block_shape_q,
max_input_length,
max_enc_len_this_time_data,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -240,6 +242,7 @@ std::vector AppendAttentionKernel(
encoder_block_shape_q,
max_input_length,
max_enc_len_this_time_data,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -282,6 +285,7 @@ std::vector AppendAttentionKernel(
encoder_block_shape_q,
max_input_length,
max_enc_len_this_time_data,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -428,6 +432,7 @@ std::vector AppendAttentionKernel(
decoder_block_shape_q,
max_input_length,
max_len_kv_data,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -465,6 +470,7 @@ std::vector AppendAttentionKernel(
decoder_block_shape_q,
max_input_length,
max_len_kv_data,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -508,6 +514,7 @@ std::vector AppendAttentionKernel(
decoder_block_shape_q,
max_input_length,
max_len_kv_data,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -565,6 +572,7 @@ std::vector AppendAttention(
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const int max_input_length,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
@@ -578,9 +586,10 @@ std::vector AppendAttention(
meta_data.token_nums = qkv_dims[0];
meta_data.kv_num_heads = key_cache_dims[1];
meta_data.head_dims = key_cache_dims[3];
- const int total_num_head =
- qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims;
- meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads;
+ meta_data.head_dims_v = value_cache.dims()[3];
+ const int q_hidden_size =
+ qkv_dims[qkv_dims.size() - 1] - meta_data.kv_num_heads * (meta_data.head_dims + meta_data.head_dims_v);
+ meta_data.q_num_heads = q_hidden_size / meta_data.head_dims;
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = key_cache.dims()[2];
@@ -626,6 +635,7 @@ std::vector AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -672,6 +682,7 @@ std::vector AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -719,6 +730,7 @@ std::vector AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -764,6 +776,7 @@ std::vector AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -821,10 +834,12 @@ std::vector> AppendAttentionInferShape(
const paddle::optional>& out_linear_smooths_shape) {
const int token_num = qkv_shape[0];
const int kv_num_heads = key_cache_shape[1];
- const int head_dim = key_cache_shape[3];
- const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
- const int num_heads = total_num_head - 2 * kv_num_heads;
- return {{token_num, num_heads * head_dim}, qkv_shape};
+ const int head_dim_qk = key_cache_shape[3];
+ const int head_dim_v = value_cache_shape[3];
+ const int q_hidden_size =
+ qkv_shape[qkv_shape.size() - 1] - kv_num_heads * (head_dim_qk + head_dim_v);
+ const int num_heads = q_hidden_size / head_dim_qk;
+ return {{token_num, num_heads * head_dim_v}, qkv_shape};
}
std::vector AppendAttentionInferDtype(
@@ -865,6 +880,7 @@ std::vector AppendAttentionInferDtype(
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const int max_input_length,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
@@ -941,6 +957,7 @@ PD_BUILD_OP(append_attention)
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"max_input_length: int",
+ "softmax_scale: float",
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",
diff --git a/csrc/gpu/append_attn/append_attention_c16_impl.cuh b/csrc/gpu/append_attn/append_attention_c16_impl.cuh
index 3b08d0a85dbc..8b75fa13cdca 100644
--- a/csrc/gpu/append_attn/append_attention_c16_impl.cuh
+++ b/csrc/gpu/append_attn/append_attention_c16_impl.cuh
@@ -23,15 +23,17 @@ template
__global__ void multi_query_append_attention_kernel(
- T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
+ T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
T *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
T *__restrict__ cache_v,
@@ -46,7 +48,7 @@ __global__ void multi_query_append_attention_kernel(
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
- const float scale,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -57,7 +59,9 @@ __global__ void multi_query_append_attention_kernel(
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5) {
- constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b();
+ constexpr uint32_t num_vecs_per_head_qk =
+ HEAD_DIM_QK / num_elems_per_128b();
+ constexpr uint32_t num_vecs_per_head_v = HEAD_DIM_V / num_elems_per_128b();
const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
const uint32_t kv_num_heads = gridDim.z;
const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE;
@@ -104,25 +108,30 @@ __global__ void multi_query_append_attention_kernel(
extern __shared__ uint8_t smem[];
float s_frag[num_frags_x][num_frags_z][8];
- float o_frag[num_frags_x][num_frags_y][8];
+ float o_frag[num_frags_x][num_frags_y_v][8];
float m_frag[num_frags_x][2];
float d_frag[num_frags_x][2];
- init_states(o_frag, m_frag, d_frag);
-
- const uint32_t q_n_stride = q_num_heads * HEAD_DIM;
- const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM;
- const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
- const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
- const uint32_t kv_b_stride = HEAD_DIM;
+ init_states(o_frag, m_frag, d_frag);
+
+ const uint32_t q_n_stride = q_num_heads * HEAD_DIM_V;
+ const uint32_t q_ori_n_stride = q_num_heads * HEAD_DIM_QK +
+ kv_num_heads * HEAD_DIM_QK +
+ kv_num_heads * HEAD_DIM_V;
+ const uint32_t k_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM_QK;
+ const uint32_t k_h_stride = BLOCK_SIZE * HEAD_DIM_QK;
+ const uint32_t k_b_stride = HEAD_DIM_QK;
+ const uint32_t v_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM_V;
+ const uint32_t v_h_stride = BLOCK_SIZE * HEAD_DIM_V;
+ const uint32_t v_b_stride = HEAD_DIM_V;
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_base_seq_id_this_block =
(tile_id * NUM_WARPS + wid) * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
- q_head_idx * HEAD_DIM +
+ q_head_idx * HEAD_DIM_QK +
tid % 8 * num_elems_per_128b();
const uint32_t o_offset = q_start_seq_id * q_n_stride +
- q_head_idx * HEAD_DIM +
+ q_head_idx * HEAD_DIM_V +
tid % 8 * num_elems_per_128b();
T *q_base_ptr = q + q_offset;
T *o_base_ptr_T = nullptr;
@@ -130,13 +139,13 @@ __global__ void multi_query_append_attention_kernel(
if constexpr (partition_kv) {
if (ENABLE_PREFILL) {
o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
- chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+ chunk_idx * q_n_stride + q_head_idx * HEAD_DIM_V +
tid % 8 * num_elems_per_128b();
} else {
o_base_ptr_T =
tmp_workspace +
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
- chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+ chunk_idx * q_n_stride + q_head_idx * HEAD_DIM_V +
tid % 8 * num_elems_per_128b();
}
} else {
@@ -144,24 +153,42 @@ __global__ void multi_query_append_attention_kernel(
}
smem_t qo_smem(smem);
- uint32_t q_smem_offset_r = smem_t::get_permuted_offset(
+ uint32_t q_smem_offset_r = smem_t::get_permuted_offset(
wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16
- load_q_global_smem(
+ load_q_global_smem(
q_base_ptr,
&qo_smem,
q_base_seq_id_this_block,
q_end,
q_ori_n_stride,
- HEAD_DIM);
+ HEAD_DIM_QK);
commit_group();
wait_group<0>();
__syncthreads();
+#ifdef DEBUG_PERCISION
+ if (tid == 0 && threadIdx.y == 0 && blockIdx.z == 0 && blockIdx.x == 0) {
+ printf("q_smem(%d * 192个bfloat16):\n", 4 * num_frags_x * 16);
+ // const uint32_t k_num = num_frags_z * 64 * HEAD_DIM / 2 * sizeof(CacheT);
+ T *q_smem_t = reinterpret_cast(qo_smem.base);
+ for (uint32_t i = 0; i < 4 * num_frags_x * 16; ++i) {
+ printf("q_smem[%d]:", (int)i);
+ for (uint32_t j = 0; j < HEAD_DIM_QK / 8; ++j) {
+ printf("[");
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.2f ", (float)q_smem_t[i * HEAD_DIM_QK + j * 8 + k]);
+ }
+ printf("]");
+ }
+ printf("\n");
+ }
+ }
+ __syncthreads();
+#endif
+ q_smem_inplace_multiply_sm_scale(
+ &qo_smem, softmax_scale);
- q_smem_inplace_multiply_sm_scale(&qo_smem,
- scale);
-
- smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
- v_smem(smem + (NUM_WARPS * num_frags_x + num_frags_z) * 16 * HEAD_DIM *
+ smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM_QK * sizeof(T)),
+ v_smem(smem + (NUM_WARPS * num_frags_x + num_frags_z) * 16 * HEAD_DIM_QK *
sizeof(T));
@@ -182,50 +209,55 @@ __global__ void multi_query_append_attention_kernel(
chunk_start)))
: chunk_len) /
(num_frags_z * 16);
- uint32_t k_smem_offset_r = smem_t::get_permuted_offset(
+ uint32_t k_smem_offset_r = smem_t::get_permuted_offset(
8 * (tid / 16) + tid % 8, (tid % 16) / 8);
uint32_t v_smem_offset_r =
- smem_t::get_permuted_offset(tid % 16, tid / 16);
+ smem_t::get_permuted_offset(tid % 16, tid / 16);
- uint32_t kv_smem_offset_w = smem_t::get_permuted_offset(
+ uint32_t k_smem_offset_w = smem_t::get_permuted_offset(
+ wid * 4 + tid / 8, tid % 8);
+ uint32_t v_smem_offset_w = smem_t::get_permuted_offset(
wid * 4 + tid / 8, tid % 8);
uint32_t kv_idx_base = chunk_start;
int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]);
- const uint32_t const_offset = kv_head_idx * kv_h_stride +
- (wid * 4 + tid / 8) * kv_b_stride +
- tid % 8 * num_elems_per_128b();
- T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset;
- T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset;
+ const uint32_t const_offset_k = kv_head_idx * k_h_stride +
+ (wid * 4 + tid / 8) * k_b_stride +
+ tid % 8 * num_elems_per_128b();
+ const uint32_t const_offset_v = kv_head_idx * v_h_stride +
+ (wid * 4 + tid / 8) * v_b_stride +
+ tid % 8 * num_elems_per_128b();
+ T *cache_k_now = cache_k + block_id * k_n_stride + const_offset_k;
+ T *cache_v_now = cache_v + block_id * v_n_stride + const_offset_v;
produce_kv_blockwise(k_smem,
- &kv_smem_offset_w,
+ &k_smem_offset_w,
&cache_k_now,
kv_head_idx,
- kv_n_stride,
- kv_h_stride,
- kv_b_stride,
+ k_n_stride,
+ k_h_stride,
+ k_b_stride,
kv_idx_base,
chunk_end);
commit_group();
produce_kv_blockwise(v_smem,
- &kv_smem_offset_w,
+ &v_smem_offset_w,
&cache_v_now,
kv_head_idx,
- kv_n_stride,
- kv_h_stride,
- kv_b_stride,
+ v_n_stride,
+ v_h_stride,
+ v_b_stride,
kv_idx_base,
chunk_end);
commit_group();
@@ -233,10 +265,45 @@ __global__ void multi_query_append_attention_kernel(
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
wait_group<1>();
__syncthreads();
-
+#ifdef DEBUG_PERCISION
+ if (tid == 0 && threadIdx.y == 0 && blockIdx.z == 0 && blockIdx.x == 0) {
+ printf("k_smem(%d * 192个bfloat16):\n", num_frags_z * 16);
+ // const uint32_t k_num = num_frags_z * 64 * HEAD_DIM / 2 *
+ // sizeof(CacheT);
+ T *k_smem_t = reinterpret_cast(k_smem.base);
+ for (uint32_t i = 0; i < num_frags_z * 16; ++i) {
+ printf("k_smem[%d]:", (int)i);
+ for (uint32_t j = 0; j < HEAD_DIM_QK / 8; ++j) {
+ printf("[");
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.2f ", (float)k_smem_t[i * HEAD_DIM_QK + j * 8 + k]);
+ }
+ printf("]");
+ }
+ printf("\n");
+ }
+ }
+ __syncthreads();
+#endif
// s = qk
- compute_qk(
+ compute_qk(
&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
+#ifdef DEBUG_PERCISION
+ __syncthreads();
+ if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.z == 0 &&
+ blockIdx.x == 0) {
+ for (uint32_t i = 0; i < num_frags_x; ++i) {
+ for (uint32_t j = 0; j < num_frags_z; ++j) {
+ printf("s_frag[%d][%d]:\n", i, j);
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.4f ", s_frag[i][j][k]);
+ }
+ printf("\n");
+ }
+ }
+ }
+ __syncthreads();
+#endif
// mask according to kv_idx and q_idx
if (iter >= mask_check_iteration) {
mask_s(q_base_seq_id_this_block,
kv_idx_base,
q_len,
@@ -255,7 +322,7 @@ __global__ void multi_query_append_attention_kernel(
}
// update m,d
- update_mdo_states(
+ update_mdo_states(
s_frag, o_frag, m_frag, d_frag);
__syncthreads();
@@ -264,43 +331,77 @@ __global__ void multi_query_append_attention_kernel(
if (block_id < 0) {
block_id = 0;
}
- cache_k_now = cache_k + block_id * kv_n_stride + const_offset;
+ cache_k_now = cache_k + block_id * k_n_stride + const_offset_k;
produce_kv_blockwise(k_smem,
- &kv_smem_offset_w,
+ &k_smem_offset_w,
&cache_k_now,
kv_head_idx,
- kv_n_stride,
- kv_h_stride,
- kv_b_stride,
+ k_n_stride,
+ k_h_stride,
+ k_b_stride,
kv_idx_base,
chunk_end);
commit_group();
wait_group<1>();
__syncthreads();
-
+#ifdef DEBUG_PERCISION
+ if (tid == 0 && threadIdx.y == 0 && blockIdx.z == 0 && blockIdx.x == 0) {
+ printf("v_smem(%d * 128个bfloat16):\n", num_frags_z * 16);
+ // const uint32_t k_num = num_frags_z * 64 * HEAD_DIM / 2 *
+ // sizeof(CacheT);
+ T *v_smem_t = reinterpret_cast(v_smem.base);
+ for (uint32_t i = 0; i < num_frags_z * 16; ++i) {
+ printf("v_smem[%d]:", (int)i);
+ for (uint32_t j = 0; j < HEAD_DIM_V / 8; ++j) {
+ printf("[");
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.2f ", (float)v_smem_t[i * HEAD_DIM_V + j * 8 + k]);
+ }
+ printf("]");
+ }
+ printf("\n");
+ }
+ }
+ __syncthreads();
+#endif
// compute sfm*v
- compute_sfm_v(
+ compute_sfm_v(
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag);
-
+#ifdef DEBUG_PERCISION
+ __syncthreads();
+ if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.z == 0 &&
+ blockIdx.x == 0) {
+ for (uint32_t i = 0; i < num_frags_x; ++i) {
+ for (uint32_t j = 0; j < num_frags_y_v; ++j) {
+ printf("o_frag[%d][%d]:\n", i, j);
+ for (uint32_t k = 0; k < 8; ++k) {
+          printf("%.4f ", o_frag[i][j][k]);
+ }
+ printf("\n");
+ }
+ }
+ }
__syncthreads();
- cache_v_now = cache_v + block_id * kv_n_stride + const_offset;
+#endif
+ __syncthreads();
+ cache_v_now = cache_v + block_id * v_n_stride + const_offset_v;
produce_kv_blockwise(v_smem,
- &kv_smem_offset_w,
+ &v_smem_offset_w,
&cache_v_now,
kv_head_idx,
- kv_n_stride,
- kv_h_stride,
- kv_b_stride,
+ v_n_stride,
+ v_h_stride,
+ v_b_stride,
kv_idx_base,
chunk_end);
commit_group();
@@ -309,12 +410,28 @@ __global__ void multi_query_append_attention_kernel(
__syncthreads();
if constexpr (!partition_kv) {
- normalize_d(o_frag, d_frag);
+ normalize_d(o_frag, d_frag);
+ }
+#ifdef DEBUG_PERCISION
+ __syncthreads();
+ if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.z == 0 &&
+ blockIdx.x == 0) {
+ for (uint32_t i = 0; i < num_frags_x; ++i) {
+ for (uint32_t j = 0; j < num_frags_y_v; ++j) {
+ printf("o_frag[%d][%d]:\n", i, j);
+ for (uint32_t k = 0; k < 8; ++k) {
+          printf("%.4f ", o_frag[i][j][k]);
+ }
+ printf("\n");
+ }
+ }
}
+ __syncthreads();
+#endif
if constexpr (partition_kv) {
write_o_reg_gmem_shift_smooth_quant(
o_frag,
&qo_smem,
@@ -328,11 +445,11 @@ __global__ void multi_query_append_attention_kernel(
in_scale,
q_len,
partition_kv ? q_n_stride * num_chunks : q_n_stride,
- HEAD_DIM);
+ HEAD_DIM_V);
} else {
write_o_reg_gmem_shift_smooth_quant(
o_frag,
&qo_smem,
@@ -346,7 +463,7 @@ __global__ void multi_query_append_attention_kernel(
in_scale,
q_len,
partition_kv ? q_n_stride * num_chunks : q_n_stride,
- HEAD_DIM);
+ HEAD_DIM_V);
}
@@ -387,15 +504,17 @@ template
__global__ void multi_query_append_attention_warp1_4_kernel(
- T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
+ T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
T *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
T *__restrict__ cache_v,
@@ -410,7 +529,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
- const float scale,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -421,7 +540,9 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5) {
- constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b();
+ constexpr uint32_t num_vecs_per_head_qk =
+ HEAD_DIM_QK / num_elems_per_128b();
+ constexpr uint32_t num_vecs_per_head_v = HEAD_DIM_V / num_elems_per_128b();
static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
@@ -467,24 +588,29 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
extern __shared__ uint8_t smem[];
float s_frag[num_frags_x][num_frags_z][8];
- float o_frag[num_frags_x][num_frags_y][8];
+ float o_frag[num_frags_x][num_frags_y_v][8];
float m_frag[num_frags_x][2];
float d_frag[num_frags_x][2];
- init_states(o_frag, m_frag, d_frag);
-
- const uint32_t q_n_stride = q_num_heads * HEAD_DIM;
- const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM;
- const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
- const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
- const uint32_t kv_b_stride = HEAD_DIM;
+ init_states(o_frag, m_frag, d_frag);
+
+ const uint32_t q_n_stride = q_num_heads * HEAD_DIM_V;
+ const uint32_t q_ori_n_stride = q_num_heads * HEAD_DIM_QK +
+ kv_num_heads * HEAD_DIM_QK +
+ kv_num_heads * HEAD_DIM_V;
+ const uint32_t k_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM_QK;
+ const uint32_t k_h_stride = BLOCK_SIZE * HEAD_DIM_QK;
+ const uint32_t k_b_stride = HEAD_DIM_QK;
+ const uint32_t v_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM_V;
+ const uint32_t v_h_stride = BLOCK_SIZE * HEAD_DIM_V;
+ const uint32_t v_b_stride = HEAD_DIM_V;
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
- q_head_idx * HEAD_DIM +
+ q_head_idx * HEAD_DIM_QK +
tid % 8 * num_elems_per_128b();
const uint32_t o_offset = q_start_seq_id * q_n_stride +
- q_head_idx * HEAD_DIM +
+ q_head_idx * HEAD_DIM_V +
tid % 8 * num_elems_per_128b();
T *q_base_ptr = q + q_offset;
T *o_base_ptr_T = nullptr;
@@ -494,41 +620,59 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
} else {
if (ENABLE_PREFILL) {
o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride +
- chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+ chunk_idx * q_n_stride + q_head_idx * HEAD_DIM_V +
tid % 8 * num_elems_per_128b();
} else {
o_base_ptr_T =
tmp_workspace +
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
- chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+ chunk_idx * q_n_stride + q_head_idx * HEAD_DIM_V +
tid % 8 * num_elems_per_128b();
}
}
smem_t qo_smem(smem);
- uint32_t q_smem_offset_r = smem_t::get_permuted_offset(
+ uint32_t q_smem_offset_r = smem_t::get_permuted_offset(
tid % 16, tid / 16); // 16 * 16
load_q_global_smem_multi_warps(q_base_ptr,
&qo_smem,
q_base_seq_id_this_block,
q_end,
q_ori_n_stride,
- HEAD_DIM);
+ HEAD_DIM_QK);
commit_group();
wait_group<0>();
__syncthreads();
+#ifdef DEBUG_PERCISION_DEC
+ if (tid == 0 && threadIdx.y == 0 && blockIdx.z == 0 && blockIdx.x == 0) {
+ printf("q_smem(%d * 192个bfloat16):\n", num_frags_x * 16);
+ // const uint32_t k_num = num_frags_z * 64 * HEAD_DIM / 2 * sizeof(CacheT);
+ T *q_smem_t = reinterpret_cast(qo_smem.base);
+ for (uint32_t i = 0; i < 4 * num_frags_x * 16; ++i) {
+ printf("q_smem[%d]:", (int)i);
+ for (uint32_t j = 0; j < HEAD_DIM_QK / 8; ++j) {
+ printf("[");
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.2f ", (float)q_smem_t[i * HEAD_DIM_QK + j * 8 + k]);
+ }
+ printf("]");
+ }
+ printf("\n");
+ }
+ }
+ __syncthreads();
+#endif
+ q_smem_inplace_multiply_sm_scale_multi_warps(
+ &qo_smem, softmax_scale);
- q_smem_inplace_multiply_sm_scale_multi_warps(
- &qo_smem, scale);
-
- smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
- v_smem(smem + (num_frags_x + NUM_WARP_KV * num_frags_z) * 16 * HEAD_DIM *
- sizeof(T));
+ smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM_QK * sizeof(T)),
+ v_smem(smem + (num_frags_x + NUM_WARP_KV * num_frags_z) * 16 *
+ HEAD_DIM_QK * sizeof(T));
const uint32_t num_iterations = div_up(
CAUSAL
@@ -548,34 +692,39 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
: chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
- uint32_t k_smem_offset_r = smem_t::get_permuted_offset(
+ uint32_t k_smem_offset_r = smem_t::get_permuted_offset(
wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
- uint32_t v_smem_offset_r = smem_t::get_permuted_offset(
+ uint32_t v_smem_offset_r = smem_t::get_permuted_offset(
wid * num_frags_z * 16 + tid % 16, tid / 16);
- uint32_t kv_smem_offset_w = smem_t::get_permuted_offset(
+ uint32_t k_smem_offset_w = smem_t::get_permuted_offset(
+ wid * 4 + tid / 8, tid % 8);
+ uint32_t v_smem_offset_w = smem_t::get_permuted_offset(
wid * 4 + tid / 8, tid % 8);
uint32_t kv_idx_base = chunk_start;
int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]);
- const uint32_t const_offset = kv_head_idx * kv_h_stride +
- (wid * 4 + tid / 8) * kv_b_stride +
- tid % 8 * num_elems_per_128b();
- T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset;
- T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset;
+ const uint32_t const_offset_k = kv_head_idx * k_h_stride +
+ (wid * 4 + tid / 8) * k_b_stride +
+ tid % 8 * num_elems_per_128b();
+ const uint32_t const_offset_v = kv_head_idx * v_h_stride +
+ (wid * 4 + tid / 8) * v_b_stride +
+ tid % 8 * num_elems_per_128b();
+ T *cache_k_now = cache_k + block_id * k_n_stride + const_offset_k;
+ T *cache_v_now = cache_v + block_id * v_n_stride + const_offset_v;
produce_kv_blockwise(k_smem,
- &kv_smem_offset_w,
+ &k_smem_offset_w,
&cache_k_now,
kv_head_idx,
- kv_n_stride,
- kv_h_stride,
- kv_b_stride,
+ k_n_stride,
+ k_h_stride,
+ k_b_stride,
kv_idx_base,
chunk_end);
commit_group();
@@ -583,15 +732,15 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
produce_kv_blockwise(v_smem,
- &kv_smem_offset_w,
+ &v_smem_offset_w,
&cache_v_now,
kv_head_idx,
- kv_n_stride,
- kv_h_stride,
- kv_b_stride,
+ v_n_stride,
+ v_h_stride,
+ v_b_stride,
kv_idx_base,
chunk_end);
commit_group();
@@ -600,10 +749,45 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
wait_group<1>();
__syncthreads();
-
+#ifdef DEBUG_PERCISION_DEC
+ if (tid == 0 && threadIdx.y == 0 && blockIdx.z == 0 && blockIdx.x == 0) {
+ printf("k_smem(%d * 192个bfloat16):\n", 4 * num_frags_z * 16);
+ // const uint32_t k_num = num_frags_z * 64 * HEAD_DIM / 2 *
+ // sizeof(CacheT);
+ T *k_smem_t = reinterpret_cast(k_smem.base);
+ for (uint32_t i = 0; i < num_frags_z * 16; ++i) {
+ printf("k_smem[%d]:", (int)i);
+ for (uint32_t j = 0; j < HEAD_DIM_QK / 8; ++j) {
+ printf("[");
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.2f ", (float)k_smem_t[i * HEAD_DIM_QK + j * 8 + k]);
+ }
+ printf("]");
+ }
+ printf("\n");
+ }
+ }
+ __syncthreads();
+#endif
// s = qk
- compute_qk(
+ compute_qk(
&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
+#ifdef DEBUG_PERCISION_DEC
+ __syncthreads();
+ if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.z == 0 &&
+ blockIdx.x == 0) {
+ for (uint32_t i = 0; i < num_frags_x; ++i) {
+ for (uint32_t j = 0; j < num_frags_z; ++j) {
+ printf("s_frag[%d][%d]:\n", i, j);
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.4f ", s_frag[i][j][k]);
+ }
+ printf("\n");
+ }
+ }
+ }
+ __syncthreads();
+#endif
// mask according to kv_idx and q_idx
if (iter >= mask_check_iteration) {
mask_s(q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
@@ -622,7 +806,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
}
// update m,d
- update_mdo_states(
+ update_mdo_states(
s_frag, o_frag, m_frag, d_frag);
__syncthreads();
@@ -631,43 +815,77 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
if (block_id < 0) {
block_id = 0;
}
- cache_k_now = cache_k + block_id * kv_n_stride + const_offset;
+ cache_k_now = cache_k + block_id * k_n_stride + const_offset_k;
produce_kv_blockwise(k_smem,
- &kv_smem_offset_w,
+ &k_smem_offset_w,
&cache_k_now,
kv_head_idx,
- kv_n_stride,
- kv_h_stride,
- kv_b_stride,
+ k_n_stride,
+ k_h_stride,
+ k_b_stride,
kv_idx_base,
chunk_end);
commit_group();
wait_group<1>();
__syncthreads();
-
+#ifdef DEBUG_PERCISION_DEC
+ if (tid == 0 && threadIdx.y == 0 && blockIdx.z == 0 && blockIdx.x == 0) {
+ printf("v_smem(%d * 128个bfloat16):\n", 4 * num_frags_z * 16);
+ // const uint32_t k_num = num_frags_z * 64 * HEAD_DIM / 2 *
+ // sizeof(CacheT);
+ T *v_smem_t = reinterpret_cast(v_smem.base);
+ for (uint32_t i = 0; i < num_frags_z * 16; ++i) {
+ printf("v_smem[%d]:", (int)i);
+ for (uint32_t j = 0; j < HEAD_DIM_V / 8; ++j) {
+ printf("[");
+ for (uint32_t k = 0; k < 8; ++k) {
+ printf("%.2f ", (float)v_smem_t[i * HEAD_DIM_V + j * 8 + k]);
+ }
+ printf("]");
+ }
+ printf("\n");
+ }
+ }
+ __syncthreads();
+#endif
// compute sfm*v
- compute_sfm_v(
+ compute_sfm_v(
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag);
__syncthreads();
-
- cache_v_now = cache_v + block_id * kv_n_stride + const_offset;
+#ifdef DEBUG_PERCISION_DEC
+ __syncthreads();
+ if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.z == 0 &&
+ blockIdx.x == 0) {
+ for (uint32_t i = 0; i < num_frags_x; ++i) {
+ for (uint32_t j = 0; j < num_frags_y_v; ++j) {
+ printf("o_frag[%d][%d]:\n", i, j);
+ for (uint32_t k = 0; k < 8; ++k) {
+          printf("%.4f ", o_frag[i][j][k]);
+ }
+ printf("\n");
+ }
+ }
+ }
+ __syncthreads();
+#endif
+ cache_v_now = cache_v + block_id * v_n_stride + const_offset_v;
produce_kv_blockwise(v_smem,
- &kv_smem_offset_w,
+ &v_smem_offset_w,
&cache_v_now,
kv_head_idx,
- kv_n_stride,
- kv_h_stride,
- kv_b_stride,
+ v_n_stride,
+ v_h_stride,
+ v_b_stride,
kv_idx_base,
chunk_end);
commit_group();
@@ -675,19 +893,34 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
wait_group<0>();
__syncthreads();
- merge_block_res_v2(
+ merge_block_res_v2(
o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid);
if (num_chunks_this_seq <= 1) {
- normalize_d(o_frag, d_frag);
+ normalize_d(o_frag, d_frag);
}
-
+#ifdef DEBUG_PERCISION_DEC
+ __syncthreads();
+ if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.z == 0 &&
+ blockIdx.x == 0) {
+ for (uint32_t i = 0; i < num_frags_x; ++i) {
+ for (uint32_t j = 0; j < num_frags_y_v; ++j) {
+ printf("o_frag[%d][%d]:\n", i, j);
+ for (uint32_t k = 0; k < 8; ++k) {
+          printf("%.4f ", o_frag[i][j][k]);
+ }
+ printf("\n");
+ }
+ }
+ }
+ __syncthreads();
+#endif
// write o
// [num_frags_x, 16, num_frags_y, 16]
if (num_chunks_this_seq <= 1) {
write_o_reg_gmem_multi_warps_shift_smooth_quant(
o_frag,
&qo_smem,
@@ -701,11 +934,11 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
in_scale,
q_len,
q_n_stride,
- HEAD_DIM);
+ HEAD_DIM_V);
} else {
write_o_reg_gmem_multi_warps_shift_smooth_quant(
o_frag,
&qo_smem,
@@ -719,7 +952,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
in_scale,
q_len,
q_n_stride * num_chunks,
- HEAD_DIM);
+ HEAD_DIM_V);
}
if (num_chunks_this_seq > 1) {
@@ -757,7 +990,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
template ;
if (smem_size >= 48 * 1024) {
@@ -853,11 +1090,13 @@ void MultiQueryAppendAttention(
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
- HEAD_DIM,
+ HEAD_DIM_QK,
+ HEAD_DIM_V,
BLOCK_SIZE,
num_frags_x,
num_frags_z,
- num_frags_y,
+ num_frags_y_qk,
+ num_frags_y_v,
OUT_NV_TYPE,
ENABLE_PREFILL>;
if (smem_size >= 48 * 1024) {
@@ -885,7 +1124,7 @@ void MultiQueryAppendAttention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -899,9 +1138,10 @@ void MultiQueryAppendAttention(
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (ENABLE_PREFILL) {
- tmp_workspace = allocator->Allocate(
- phi::SizeOf(qkv.dtype()) *
- static_cast(token_num * num_chunks * num_heads * HEAD_DIM));
+ tmp_workspace =
+ allocator->Allocate(phi::SizeOf(qkv.dtype()) *
+ static_cast(token_num * num_chunks *
+ num_heads * HEAD_DIM_V));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast(token_num * num_chunks * num_heads));
@@ -912,7 +1152,7 @@ void MultiQueryAppendAttention(
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
static_cast(speculate_max_draft_token_num * bsz *
- num_chunks * num_heads * HEAD_DIM));
+ num_chunks * num_heads * HEAD_DIM_V));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast(speculate_max_draft_token_num * bsz *
@@ -942,7 +1182,7 @@ void MultiQueryAppendAttention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -955,14 +1195,14 @@ void MultiQueryAppendAttention(
// merge
constexpr int vec_size = num_elems_per_128b();
if (is_decoder) {
- constexpr int blockx = HEAD_DIM / vec_size;
+ constexpr int blockx = HEAD_DIM_V / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(bsz, num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_decoder_kernel
<<>>(
@@ -987,9 +1227,9 @@ void MultiQueryAppendAttention(
num_chunks,
num_heads,
chunk_size,
- HEAD_DIM);
+ HEAD_DIM_V);
} else {
- constexpr int blockx = HEAD_DIM / vec_size;
+ constexpr int blockx = HEAD_DIM_V / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads); // 128k is too large
@@ -997,7 +1237,7 @@ void MultiQueryAppendAttention(
merge_multi_chunks_v2_kernel
<<>>(
@@ -1022,7 +1262,7 @@ void MultiQueryAppendAttention(
num_chunks,
num_heads,
chunk_size,
- HEAD_DIM,
+ HEAD_DIM_V,
token_num,
speculate_max_draft_token_num);
}
@@ -1030,8 +1270,9 @@ void MultiQueryAppendAttention(
} else {
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV;
constexpr uint32_t smem_size =
- (num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * HEAD_DIM *
- sizeof(T);
+ ((num_frags_x + NUM_WARP_KV * num_frags_z) * HEAD_DIM_QK +
+ NUM_WARP_KV * num_frags_z * HEAD_DIM_V) *
+ 16 * sizeof(T);
auto split_kv_kernel =
multi_query_append_attention_warp1_4_kernel;
if (smem_size >= 48 * 1024) {
@@ -1074,11 +1317,13 @@ void MultiQueryAppendAttention(
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
- HEAD_DIM,
+ HEAD_DIM_QK,
+ HEAD_DIM_V,
BLOCK_SIZE,
num_frags_x,
num_frags_z,
- num_frags_y,
+ num_frags_y_qk,
+ num_frags_y_v,
OUT_NV_TYPE,
ENABLE_PREFILL>;
if (smem_size >= 48 * 1024) {
@@ -1106,7 +1351,7 @@ void MultiQueryAppendAttention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -1121,7 +1366,7 @@ void MultiQueryAppendAttention(
if (is_decoder) {
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
- static_cast(bsz * num_chunks * num_heads * HEAD_DIM));
+ static_cast(bsz * num_chunks * num_heads * HEAD_DIM_V));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast(bsz * num_chunks * num_heads));
@@ -1133,7 +1378,7 @@ void MultiQueryAppendAttention(
tmp_workspace =
allocator->Allocate(phi::SizeOf(qkv.dtype()) *
static_cast(token_num * num_chunks *
- num_heads * HEAD_DIM));
+ num_heads * HEAD_DIM_V));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast(token_num * num_chunks * num_heads));
@@ -1144,7 +1389,7 @@ void MultiQueryAppendAttention(
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
static_cast(speculate_max_draft_token_num * bsz *
- num_chunks * num_heads * HEAD_DIM));
+ num_chunks * num_heads * HEAD_DIM_V));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast(speculate_max_draft_token_num * bsz *
@@ -1174,7 +1419,7 @@ void MultiQueryAppendAttention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -1188,14 +1433,14 @@ void MultiQueryAppendAttention(
// merge
constexpr int vec_size = num_elems_per_128b();
if (is_decoder) {
- constexpr int blockx = HEAD_DIM / vec_size;
+ constexpr int blockx = HEAD_DIM_V / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(bsz, num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_decoder_kernel
<<>>(
@@ -1220,17 +1465,16 @@ void MultiQueryAppendAttention(
num_chunks,
num_heads,
chunk_size,
- HEAD_DIM);
+ HEAD_DIM_V);
} else {
- constexpr int blockx = HEAD_DIM / vec_size;
+ constexpr int blockx = HEAD_DIM_V / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
- dim3 grids_merge(min(sm_count * 4, token_num),
- num_heads);
+ dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel
<<>>(
@@ -1255,7 +1499,7 @@ void MultiQueryAppendAttention(
num_chunks,
num_heads,
chunk_size,
- HEAD_DIM,
+ HEAD_DIM_V,
token_num,
speculate_max_draft_token_num);
}
@@ -1265,37 +1509,39 @@ void MultiQueryAppendAttention(
template
void CascadeAppendAttentionC16Kernel(
- const AppendAttnMetaData& meta_data,
- const paddle::Tensor& qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
- const paddle::Tensor&
- cache_k, // [max_block_num, num_heads, block_size, head_dim]
- const paddle::Tensor&
- cache_v, // [max_block_num, num_heads, head_dim, block_size]
- const paddle::optional& attn_mask,
- const paddle::optional&
- cache_k_scale, // [num_kv_heads, head_dim]
- const paddle::optional&
- cache_v_scale, // [num_kv_heads, head_dim]
- const paddle::optional&
- cache_k_zp, // [num_kv_heads, head_dim]
- const paddle::optional&
- cache_v_zp, // [num_kv_heads, head_dim]
- const paddle::optional&
- shift_bias, // [num_kv_heads, head_dim]
- const paddle::optional&
- smooth_weight, // [num_kv_heads, head_dim]
- const paddle::Tensor& seq_lens_q,
- const paddle::Tensor& seq_lens_kv,
- const paddle::Tensor& seq_lens_encoder,
- const paddle::Tensor& padding_offsets,
- const paddle::Tensor& cum_offsets,
- const paddle::Tensor& block_table,
- const paddle::Tensor& batch_ids,
- const paddle::Tensor& tile_ids_per_batch,
+ const AppendAttnMetaData &meta_data,
+ const paddle::Tensor
+ &qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
+ const paddle::Tensor
+ &cache_k, // [max_block_num, num_heads, block_size, head_dim]
+ const paddle::Tensor
+ &cache_v, // [max_block_num, num_heads, head_dim, block_size]
+ const paddle::optional &attn_mask,
+ const paddle::optional
+ &cache_k_scale, // [num_kv_heads, head_dim]
+ const paddle::optional
+ &cache_v_scale, // [num_kv_heads, head_dim]
+ const paddle::optional
+ &cache_k_zp, // [num_kv_heads, head_dim]
+ const paddle::optional
+ &cache_v_zp, // [num_kv_heads, head_dim]
+ const paddle::optional
+ &shift_bias, // [num_kv_heads, head_dim]
+ const paddle::optional
+ &smooth_weight, // [num_kv_heads, head_dim]
+ const paddle::Tensor &seq_lens_q,
+ const paddle::Tensor &seq_lens_kv,
+ const paddle::Tensor &seq_lens_encoder,
+ const paddle::Tensor &padding_offsets,
+ const paddle::Tensor &cum_offsets,
+ const paddle::Tensor &block_table,
+ const paddle::Tensor &batch_ids,
+ const paddle::Tensor &tile_ids_per_batch,
const int num_blocks,
const int block_shape_q,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -1303,14 +1549,15 @@ void CascadeAppendAttentionC16Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
- cudaStream_t& stream,
- paddle::Tensor* out) {
+ cudaStream_t &stream,
+ paddle::Tensor *out) {
const auto token_num = meta_data.token_nums;
const auto block_size = meta_data.block_size;
const auto bsz = meta_data.batch_size;
const auto num_heads = meta_data.q_num_heads;
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
- const auto head_dim = meta_data.head_dims;
+ const auto head_dim_qk = meta_data.head_dims;
+ const auto head_dim_v = meta_data.head_dims_v;
DISPATCH_CAUSAL(
causal,
@@ -1322,46 +1569,51 @@ void CascadeAppendAttentionC16Kernel(
group_size,
GROUP_SIZE,
{DISPATCH_HEAD_DIM(
- head_dim,
- HEAD_DIM,
- {DISPATCH_BLOCK_SIZE(
- block_size,
- BLOCK_SIZE,
- {DISPATCH_BLOCKSHAPE_Q(
- block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, {
- MultiQueryAppendAttention(
- meta_data,
- qkv,
- cache_k,
- cache_v,
- attn_mask,
- shift_bias,
- smooth_weight,
- seq_lens_q,
- seq_lens_kv,
- seq_lens_encoder,
- padding_offsets,
- cum_offsets,
- block_table,
- batch_ids,
- tile_ids_per_batch,
- num_blocks,
- max_seq_len,
- max_dec_len,
- quant_max_bound,
- quant_min_bound,
- in_scale,
- speculate_max_draft_token_num,
- is_decoder,
- stream,
- out);
- })})})})})})
+ head_dim_qk,
+ HEAD_DIM_QK,
+ {DISPATCH_HEAD_DIM(
+ head_dim_v,
+ HEAD_DIM_V,
+ {DISPATCH_BLOCK_SIZE(
+ block_size,
+ BLOCK_SIZE,
+ {DISPATCH_BLOCKSHAPE_Q(
+ block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, {
+ MultiQueryAppendAttention(
+ meta_data,
+ qkv,
+ cache_k,
+ cache_v,
+ attn_mask,
+ shift_bias,
+ smooth_weight,
+ seq_lens_q,
+ seq_lens_kv,
+ seq_lens_encoder,
+ padding_offsets,
+ cum_offsets,
+ block_table,
+ batch_ids,
+ tile_ids_per_batch,
+ num_blocks,
+ max_seq_len,
+ max_dec_len,
+ softmax_scale,
+ quant_max_bound,
+ quant_min_bound,
+ in_scale,
+ speculate_max_draft_token_num,
+ is_decoder,
+ stream,
+ out);
+ })})})})})})})
}
diff --git a/csrc/gpu/append_attn/append_attention_c4_impl.cuh b/csrc/gpu/append_attn/append_attention_c4_impl.cuh
index 7d49de3966e0..fac1baf6f4c2 100644
--- a/csrc/gpu/append_attn/append_attention_c4_impl.cuh
+++ b/csrc/gpu/append_attn/append_attention_c4_impl.cuh
@@ -51,7 +51,7 @@ __global__ void multi_query_append_attention_c4_kernel(
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
- const float scale,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -189,7 +189,7 @@ __global__ void multi_query_append_attention_c4_kernel(
__syncthreads();
q_smem_inplace_multiply_sm_scale(&qo_smem,
- scale);
+ softmax_scale);
T cache_k_scale_frag[num_frags_y][4];
T cache_k_zp_frag[num_frags_y][4];
@@ -509,7 +509,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
- const float scale,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -649,7 +649,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
__syncthreads();
q_smem_inplace_multiply_sm_scale_multi_warps(
- &qo_smem, scale);
+ &qo_smem, softmax_scale);
T cache_k_scale_frag[num_frags_y][4];
T cache_k_zp_frag[num_frags_y][4];
@@ -970,6 +970,7 @@ void MultiQueryAppendC4Attention(
const int num_blocks_x_cpu,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -994,8 +995,6 @@ void MultiQueryAppendC4Attention(
auto *allocator = paddle::GetAllocator(qkv.place());
- const float scale = 1.f / sqrt(HEAD_DIM);
-
if constexpr (NUM_WARP_Q == 4) {
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16;
constexpr uint32_t smem_size =
@@ -1091,7 +1090,7 @@ void MultiQueryAppendC4Attention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -1154,7 +1153,7 @@ void MultiQueryAppendC4Attention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -1336,7 +1335,7 @@ void MultiQueryAppendC4Attention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -1412,7 +1411,7 @@ void MultiQueryAppendC4Attention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -1533,6 +1532,7 @@ void CascadeAppendAttentionC4Kernel(
const int block_shape_q,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -1597,6 +1597,7 @@ void CascadeAppendAttentionC4Kernel(
num_blocks,
max_seq_len,
max_dec_len,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
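
The C4 path keeps the per-head `cache_k_scale`/`cache_k_zp` (and the V counterparts) parameters; conceptually each 4-bit cache code is recovered as `(code - zero_point) * scale` before it enters the QK^T and PV products. A standalone sketch of that mapping; the packing order and names are illustrative, not the kernel's actual register-fragment layout:

```cpp
#include <cstdint>

// Illustrative int4 dequantization: two 4-bit codes are packed per byte and
// each is recovered as (code - zero_point) * scale. The real kernel performs
// this on MMA fragments, but the arithmetic is the same.
inline void dequant_int4_pair(uint8_t packed, float scale, float zero_point,
                              float* lo, float* hi) {
  const int lo_code = packed & 0x0F;
  const int hi_code = (packed >> 4) & 0x0F;
  *lo = (static_cast<float>(lo_code) - zero_point) * scale;
  *hi = (static_cast<float>(hi_code) - zero_point) * scale;
}
```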
diff --git a/csrc/gpu/append_attn/append_attention_c8_impl.cuh b/csrc/gpu/append_attn/append_attention_c8_impl.cuh
index e0ede51a9c81..df2357bb192b 100644
--- a/csrc/gpu/append_attn/append_attention_c8_impl.cuh
+++ b/csrc/gpu/append_attn/append_attention_c8_impl.cuh
@@ -32,7 +32,7 @@ template
__global__ void multi_query_append_attention_c8_kernel(
- T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
+ T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
@@ -49,7 +49,7 @@ __global__ void multi_query_append_attention_c8_kernel(
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
- const float scale,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -172,7 +172,7 @@ __global__ void multi_query_append_attention_c8_kernel(
__syncthreads();
q_smem_inplace_multiply_sm_scale(&qo_smem,
- scale);
+ softmax_scale);
smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
@@ -206,8 +206,7 @@ __global__ void multi_query_append_attention_c8_kernel(
uint32_t k_smem_offset_w =
smem_t::get_permuted_offset(
- wid * 4 + tid / 8,
- tid % 8);
+ wid * 4 + tid / 8, tid % 8);
uint32_t v_smem_offset_w =
smem_t::get_permuted_offset(
wid * 8 + tid / 4, tid % 4); // 4 * 128 / 8 = 64
@@ -338,7 +337,6 @@ __global__ void multi_query_append_attention_c8_kernel(
chunk_end,
const_v_offset);
commit_group();
-
}
wait_group<0>();
__syncthreads();
@@ -434,7 +432,7 @@ template
__global__ void multi_query_append_attention_c8_warp1_4_kernel(
- T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
+ T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
@@ -451,7 +449,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
- const float scale,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -575,7 +573,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
__syncthreads();
q_smem_inplace_multiply_sm_scale_multi_warps(
- &qo_smem, scale);
+ &qo_smem, softmax_scale);
smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
@@ -610,12 +608,10 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
uint32_t k_smem_offset_w =
smem_t::get_permuted_offset(
- wid * 4 + tid / 8,
- tid %
- 8);
+ wid * 4 + tid / 8, tid % 8);
uint32_t v_smem_offset_w =
smem_t::get_permuted_offset(
- wid * 8 + tid / 4, tid % 4);
+ wid * 8 + tid / 4, tid % 4);
uint32_t kv_idx_base = chunk_start;
const uint32_t const_k_offset = kv_head_idx * kv_h_stride +
@@ -805,7 +801,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
if (qo_idx - q_start_seq_id < q_len) {
-
uint32_t offset;
if (ENABLE_PREFILL) {
offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
@@ -857,6 +852,7 @@ void MultiQueryAppendC8Attention(
const int num_blocks_x_cpu,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -881,8 +877,6 @@ void MultiQueryAppendC8Attention(
auto *allocator = paddle::GetAllocator(qkv.place());
- const float scale = 1.f / sqrt(HEAD_DIM);
-
if constexpr (NUM_WARP_Q == 4) {
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16;
constexpr uint32_t smem_size =
@@ -963,7 +957,7 @@ void MultiQueryAppendC8Attention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -1020,7 +1014,7 @@ void MultiQueryAppendC8Attention(
max_seq_len,
max_dec_len,
max_block_num_per_seq,
- scale,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -1069,8 +1063,7 @@ void MultiQueryAppendC8Attention(
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
- dim3 grids_merge(min(sm_count * 4, token_num),
- num_heads);
+ dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel
void CascadeAppendAttentionC8Kernel(
- const AppendAttnMetaData& meta_data,
- const paddle::Tensor& qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
- const paddle::Tensor&
- cache_k, // [max_block_num, num_heads, block_size, head_dim]
- const paddle::Tensor&
- cache_v, // [max_block_num, num_heads, head_dim, block_size]
-    const paddle::optional<paddle::Tensor>& attn_mask,
-    const paddle::optional<paddle::Tensor>&
-        cache_k_scale,  // [num_kv_heads, head_dim]
-    const paddle::optional<paddle::Tensor>&
-        cache_v_scale,  // [num_kv_heads, head_dim]
-    const paddle::optional<paddle::Tensor>&
-        cache_k_zp,  // [num_kv_heads, head_dim]
-    const paddle::optional<paddle::Tensor>&
-        cache_v_zp,  // [num_kv_heads, head_dim]
-    const paddle::optional<paddle::Tensor>&
-        shift_bias,  // [num_kv_heads, head_dim]
-    const paddle::optional<paddle::Tensor>&
-        smooth_weight,  // [num_kv_heads, head_dim]
- const paddle::Tensor& seq_lens_q,
- const paddle::Tensor& seq_lens_kv,
- const paddle::Tensor& seq_lens_encoder,
- const paddle::Tensor& padding_offsets,
- const paddle::Tensor& cum_offsets,
- const paddle::Tensor& block_table,
- const paddle::Tensor& batch_ids,
- const paddle::Tensor& tile_ids_per_batch,
+ const AppendAttnMetaData &meta_data,
+ const paddle::Tensor
+ &qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
+ const paddle::Tensor
+ &cache_k, // [max_block_num, num_heads, block_size, head_dim]
+ const paddle::Tensor
+ &cache_v, // [max_block_num, num_heads, head_dim, block_size]
+    const paddle::optional<paddle::Tensor> &attn_mask,
+    const paddle::optional<paddle::Tensor>
+        &cache_k_scale,  // [num_kv_heads, head_dim]
+    const paddle::optional<paddle::Tensor>
+        &cache_v_scale,  // [num_kv_heads, head_dim]
+    const paddle::optional<paddle::Tensor>
+        &cache_k_zp,  // [num_kv_heads, head_dim]
+    const paddle::optional<paddle::Tensor>
+        &cache_v_zp,  // [num_kv_heads, head_dim]
+    const paddle::optional<paddle::Tensor>
+        &shift_bias,  // [num_kv_heads, head_dim]
+    const paddle::optional<paddle::Tensor>
+        &smooth_weight,  // [num_kv_heads, head_dim]
+ const paddle::Tensor &seq_lens_q,
+ const paddle::Tensor &seq_lens_kv,
+ const paddle::Tensor &seq_lens_encoder,
+ const paddle::Tensor &padding_offsets,
+ const paddle::Tensor &cum_offsets,
+ const paddle::Tensor &block_table,
+ const paddle::Tensor &batch_ids,
+ const paddle::Tensor &tile_ids_per_batch,
const int num_blocks,
const int block_shape_q,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -1379,8 +1373,8 @@ void CascadeAppendAttentionC8Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
- cudaStream_t& stream,
- paddle::Tensor* out) {
+ cudaStream_t &stream,
+ paddle::Tensor *out) {
const auto token_num = meta_data.token_nums;
const auto block_size = meta_data.block_size;
const auto bsz = meta_data.batch_size;
@@ -1434,6 +1428,7 @@ void CascadeAppendAttentionC8Kernel(
num_blocks,
max_seq_len,
max_dec_len,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
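
Both C8 kernels apply `softmax_scale` by multiplying the Q tile in shared memory (`q_smem_inplace_multiply_sm_scale*`) before the QK^T MMA, which is equivalent to scaling the attention logits. A scalar sketch of that equivalence (plain host C++, not the fragment code):

```cpp
#include <cstddef>

// Scaling q before the dot product gives the same logit as scaling (q . k)
// afterwards: (softmax_scale * q) . k == softmax_scale * (q . k). The kernels
// use the first form so the scale is paid once per Q element rather than once
// per QK^T entry.
inline float scaled_logit(const float* q, const float* k, std::size_t d,
                          float softmax_scale) {
  float acc = 0.f;
  for (std::size_t i = 0; i < d; ++i) acc += (softmax_scale * q[i]) * k[i];
  return acc;  // equals softmax_scale * dot(q, k)
}
```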
diff --git a/csrc/gpu/append_attn/append_attention_kernel.h b/csrc/gpu/append_attn/append_attention_kernel.h
index b0fabcf893d3..b34c2a044733 100644
--- a/csrc/gpu/append_attn/append_attention_kernel.h
+++ b/csrc/gpu/append_attn/append_attention_kernel.h
@@ -49,6 +49,7 @@ void CascadeAppendAttentionC16Kernel(
const int block_shape_q,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -92,6 +93,7 @@ void CascadeAppendAttentionC8Kernel(
const int block_shape_q,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -135,6 +137,7 @@ void CascadeAppendAttentionC4Kernel(
const int block_shape_q,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -179,6 +182,7 @@ void CascadeAppendAttentionKernel(
const int block_shape_q,
const int max_seq_len,
const int max_dec_len,
+ const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -212,6 +216,7 @@ void CascadeAppendAttentionKernel(
block_shape_q,
max_seq_len,
max_dec_len,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -245,6 +250,7 @@ void CascadeAppendAttentionKernel(
block_shape_q,
max_seq_len,
max_dec_len,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
@@ -278,6 +284,7 @@ void CascadeAppendAttentionKernel(
block_shape_q,
max_seq_len,
max_dec_len,
+ softmax_scale,
quant_max_bound,
quant_min_bound,
in_scale,
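
The shared header now threads `softmax_scale` through every `CascadeAppendAttention*Kernel` variant, and the implementations read both `meta_data.head_dims` (QK) and `meta_data.head_dims_v` (V). A hedged sketch of what that split could look like for a model whose value head dim differs from its query/key head dim; the field names follow the code above, but the struct layout and concrete sizes here are only an example:

```cpp
// Illustrative metadata fill-in; AppendAttnMetaData is the struct consumed by
// the kernels above, but these particular dimensions are hypothetical.
struct AppendAttnMetaDataSketch {
  int token_nums, batch_size, block_size;
  int q_num_heads, kv_num_heads;
  int head_dims;    // QK head dim: drives softmax_scale and the K-side tiles
  int head_dims_v;  // V head dim: drives the PV tiles and the output width
};

inline AppendAttnMetaDataSketch example_meta() {
  AppendAttnMetaDataSketch m{};
  m.token_nums = 1024; m.batch_size = 8; m.block_size = 64;
  m.q_num_heads = 32;  m.kv_num_heads = 8;
  m.head_dims = 192;    // e.g. a wider QK projection
  m.head_dims_v = 128;  // narrower V projection
  return m;
}
```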
diff --git a/csrc/gpu/append_attn/decoder_write_cache_with_rope_impl.cuh b/csrc/gpu/append_attn/decoder_write_cache_with_rope_impl.cuh
index 1a8e73759022..5fbb53f05801 100644
--- a/csrc/gpu/append_attn/decoder_write_cache_with_rope_impl.cuh
+++ b/csrc/gpu/append_attn/decoder_write_cache_with_rope_impl.cuh
@@ -122,6 +122,91 @@ __global__ void append_decode_cache_T_rope_kernel(
}
}
+template <typename T, int VecSize>
+__global__ void append_decode_cache_T_kernel(
+    const T* __restrict__ qkv,    // [bsz, num_heads * head_size_qk +
+                                  //  kv_num_heads * (head_size_qk + head_size_v)]
+    T* __restrict__ key_cache,    // [num_blocks, kv_num_heads, block_size,
+                                  //  head_size_qk]
+    T* __restrict__ value_cache,  // [num_blocks, kv_num_heads, block_size,
+                                  //  head_size_v]
+ const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
+ const int* __restrict__ padding_offsets, // [num_tokens]
+ const int* __restrict__ cum_offsets,
+ const int* __restrict__ seq_lens, // [bsz]
+ const int* __restrict__ seq_lens_encoder, // [bsz]
+ const int max_seq_len,
+ const int max_blocks_per_seq,
+ const int num_heads,
+ const int head_size_qk,
+ const int head_size_v,
+ const int block_size,
+ const uint32_t elem_cnt,
+ const int kv_num_heads) {
+  using LoadT = AlignedVector<T, VecSize>;
+  using LoadBiasT = AlignedVector<T, VecSize>;
+  using LoadKVT = AlignedVector<T, VecSize>;
+  constexpr int HalfVecSize = VecSize / 2;
+  using LoadEmbT = AlignedVector<float, HalfVecSize>;
+ LoadT src_vec;
+ LoadBiasT out_vec;
+ LoadKVT cache_vec;
+
+ int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+ // const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
+ const uint32_t hidden_size_q = num_heads * head_size_qk;
+ const uint32_t hidden_size_k = kv_num_heads * head_size_qk;
+ const uint32_t hidden_size_v = kv_num_heads * head_size_v;
+ const int64_t hidden_size = hidden_size_q + hidden_size_k + hidden_size_v;
+ const uint32_t offset = kv_num_heads * (head_size_qk + head_size_v);
+ // const int64_t offset = 2 * hidden_size;
+ // const int half_head_size = head_size / 2;
+ for (int32_t linear_index = global_thread_idx * VecSize,
+ step = gridDim.x * blockDim.x * VecSize;
+ linear_index < elem_cnt;
+ linear_index += step) {
+ const int ori_bi = linear_index / offset;
+ const int bias = linear_index % offset;
+ const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
+ if (seq_lens_encoder[ori_bi] > 0) return;
+ const int write_seq_id = seq_lens[ori_bi];
+
+ if (write_seq_id == 0) continue;
+
+ const int* block_table_now = nullptr;
+
+ block_table_now = block_tables + ori_bi * max_blocks_per_seq;
+ const int block_idx = block_table_now[write_seq_id / block_size];
+ const int block_offset = write_seq_id % block_size;
+
+ if (bias < hidden_size_k) {
+ const uint32_t qkv_bias = bias;
+ const uint32_t hi = qkv_bias / head_size_qk;
+ const uint32_t h_bias = qkv_bias % head_size_qk;
+ const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * head_size_qk +
+ hi * block_size * head_size_qk +
+ block_offset * head_size_qk + h_bias;
+ const uint32_t ori_idx =
+ start_token_idx * hidden_size +
+ hidden_size_q + qkv_bias;
+ Load(&qkv[ori_idx], &src_vec);
+ Store(src_vec, &key_cache[tgt_idx]);
+ } else {
+ const uint32_t qkv_bias = bias - hidden_size_k;
+ const uint32_t hi = qkv_bias / head_size_v;
+ const uint32_t h_bias = qkv_bias % head_size_v;
+ const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * head_size_v +
+ hi * block_size * head_size_v +
+ block_offset * head_size_v + h_bias;
+ const uint32_t ori_idx =
+ start_token_idx * hidden_size +
+ hidden_size_q + hidden_size_k + qkv_bias;
+ Load(&qkv[ori_idx], &src_vec);
+ Store(src_vec, &value_cache[tgt_idx]);
+ }
+ }
+}
+
template
__global__ void append_decode_cache_T_rope_kernel(
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
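
The new `append_decode_cache_T_kernel` walks a flat index over `kv_num_heads * (head_size_qk + head_size_v)` elements per batch row, looks up the physical block through `block_tables`, and writes K and V at their per-block offsets. A host-side mirror of the same index arithmetic, for illustration only (not the launch code):

```cpp
#include <cstdint>

// CPU mirror of the kernel's addressing: given a sequence's block-table row
// and a flat bias within its K (or V) segment, compute the destination offset
// inside the paged cache. Matches the tgt_idx math in the kernel above.
inline int64_t paged_kv_offset(const int* block_table_row, int write_seq_id,
                               int block_size, int kv_num_heads,
                               int head_size, int bias /* within K or V */) {
  const int block_idx = block_table_row[write_seq_id / block_size];
  const int block_offset = write_seq_id % block_size;
  const int hi = bias / head_size;      // kv head index
  const int h_bias = bias % head_size;  // position within the head
  return static_cast<int64_t>(block_idx) * kv_num_heads * block_size * head_size +
         static_cast<int64_t>(hi) * block_size * head_size +
         static_cast<int64_t>(block_offset) * head_size + h_bias;
}
```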
diff --git a/csrc/gpu/append_attn/decoder_write_cache_with_rope_kernel.cu b/csrc/gpu/append_attn/decoder_write_cache_with_rope_kernel.cu
index ee0cd57e307c..08483feb2a5c 100644
--- a/csrc/gpu/append_attn/decoder_write_cache_with_rope_kernel.cu
+++ b/csrc/gpu/append_attn/decoder_write_cache_with_rope_kernel.cu
@@ -15,6 +15,54 @@
#include "decoder_write_cache_with_rope_kernel.h"
#include "utils.cuh"
+
+template <typename T>
+void DecoderWriteCacheKV(const AppendAttnMetaData& meta_data,
+ const paddle::Tensor& qkv,
+ const paddle::Tensor& seq_lens,
+ const paddle::Tensor& seq_lens_encoder,
+ const paddle::Tensor& padding_offsets,
+ const paddle::Tensor& cum_offsets,
+ const paddle::Tensor& block_tables,
+ const int max_seq_len,
+ cudaStream_t& stream,
+ paddle::Tensor* key_cache_out,
+ paddle::Tensor* value_cache_out) {
+ auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
+ auto bsz = meta_data.batch_size;
+ auto block_size = meta_data.block_size;
+ auto head_dim_qk = meta_data.head_dims;
+ auto head_dim_v = meta_data.head_dims_v;
+ auto num_heads = meta_data.q_num_heads;
+ auto kv_num_heads = meta_data.kv_num_heads;
+ const uint32_t elem_nums = bsz * kv_num_heads * (head_dim_qk + head_dim_v);
+
+ constexpr int PackSize = 16 / sizeof(T);
+ const int pack_num = elem_nums / PackSize;
+ const int blocksize = 128;
+ int grid_size = 1;
+ GetNumBlocks<128>(pack_num, &grid_size);
+
+  append_decode_cache_T_kernel<T, PackSize>
+      <<<grid_size, blocksize, 0, stream>>>(
+          reinterpret_cast<T*>(const_cast<T*>(qkv.data<T>())),
+          reinterpret_cast<T*>(key_cache_out->data<T>()),
+          reinterpret_cast<T*>(value_cache_out->data<T>()),
+          block_tables.data<int>(),
+          padding_offsets.data<int>(),
+          cum_offsets.data<int>(),
+          seq_lens.data<int>(),
+          seq_lens_encoder.data<int>(),
+ max_seq_len,
+ max_blocks_per_seq,
+ num_heads,
+ head_dim_qk,
+ head_dim_v,
+ block_size,
+ elem_nums,
+ kv_num_heads);
+}
+
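
`DecoderWriteCacheKV` sizes its launch from one decode token per sequence: `elem_nums = bsz * kv_num_heads * (head_dim_qk + head_dim_v)`, vectorized by `PackSize = 16 / sizeof(T)`. A quick worked example of that arithmetic under assumed sizes (bf16, 8 KV heads, 192/128 head dims; all numbers are illustrative, not taken from the patch):

```cpp
#include <cstdio>

// Worked example of the launch-size math in DecoderWriteCacheKV with assumed
// model dimensions. The real GetNumBlocks helper may additionally cap the
// grid by device occupancy; a plain ceil-div is used here for clarity.
int main() {
  const int bsz = 16, kv_num_heads = 8;
  const int head_dim_qk = 192, head_dim_v = 128;
  const int elem_nums = bsz * kv_num_heads * (head_dim_qk + head_dim_v);  // 40960
  const int pack_size = 16 / 2;                 // bf16: 8 elements per 128-bit access
  const int pack_num = elem_nums / pack_size;   // 5120 vectorized work items
  const int block_dim = 128;
  const int grid = (pack_num + block_dim - 1) / block_dim;  // 40 blocks
  std::printf("elem=%d packs=%d grid=%d\n", elem_nums, pack_num, grid);
  return 0;
}
```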
template <typename T, typename QKV_TYPE>
void append_decode_cache_rope(const QKV_TYPE* qkv,
T* key_cache,
@@ -449,115 +497,125 @@ void DecoderWriteCacheWithRoPEKernel(
auto num_heads = meta_data.q_num_heads;
auto kv_num_heads = meta_data.kv_num_heads;
- const float* cos_emb =
-      rotary_embs ? rotary_embs.get().data<float>() : nullptr;
- const float* sin_emb;
if (rotary_embs) {
- sin_emb =
+    const float* cos_emb = rotary_embs.get().data<float>();
+ const float* sin_emb =
use_neox_rotary_style
            ? rotary_embs.get().data<float>() + max_seq_len * dim_head
            : rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
- }
- if (cache_quant_type_str == "none") {
- append_decode_cache_rope(
- reinterpret_cast(qkv_ptr),
- reinterpret_cast(key_cache_out->data()),
- reinterpret_cast(value_cache_out->data()),
- reinterpret_cast(qkv_out->data()),
- block_tables.data(),
- padding_offsets.data(),
- cum_offsets.data(),
- seq_lens.data(),
- seq_lens_encoder.data(),
- cos_emb,
- sin_emb,
- qkv_out_scales ? qkv_out_scales.get().data() : nullptr,
- qkv_biases ? reinterpret_cast(
- const_cast(qkv_biases.get().data()))
- : nullptr,
- max_seq_len,
- max_blocks_per_seq,
- num_heads,
- kv_num_heads,
- dim_head,
- block_size,
- bsz,
- stream,
- use_neox_rotary_style);
- } else if (cache_quant_type_str == "cache_int8") {
- append_decode_cache_int8_rope(
- reinterpret_cast(qkv_ptr),
- key_cache_out->data(),
- value_cache_out->data(),
- reinterpret_cast(qkv_out->data()),
- block_tables.data(),
- padding_offsets.data(),
- cum_offsets.data(),
- seq_lens.data(),
- seq_lens_encoder.data(),
- cos_emb,
- sin_emb,
- qkv_out_scales ? qkv_out_scales.get().data() : nullptr,
- qkv_biases ? reinterpret_cast(
- const_cast(qkv_biases.get().data()))
- : nullptr,
- cache_k_scale ? reinterpret_cast(
- const_cast(cache_k_scale.get().data()))
- : nullptr,
- cache_v_scale ? reinterpret_cast(
- const_cast(cache_v_scale.get().data()))
- : nullptr,
- max_seq_len,
- max_blocks_per_seq,
- num_heads,
- kv_num_heads,
- dim_head,
- block_size,
- bsz,
- stream,
- use_neox_rotary_style);
- } else if (cache_quant_type_str == "cache_int4_zp") {
- append_decode_cache_int4_rope(
- reinterpret_cast(qkv_ptr),
- key_cache_out->data(),
- value_cache_out->data(),
- reinterpret_cast(const_cast(qkv_out->data())),
- block_tables.data(),
- padding_offsets.data(),
- cum_offsets.data(),
- seq_lens.data