support deepseek-v3

update 0113

support head_dim=192,256 for append_attn c16

attention run

refine code

add softmax_scale

support weight_only_int8

refine code

support tp

delete test_append_attn

add split fused_moe from ziyuan

add deepseek-v3 class

fix repe for deepseek-v3

fix wint8 precision and refine code

fix wint4, big diff

add e_score_correction_bias

fix head_dim

fix v3 verify

[AutoParallel] open tensor_fusion for benchmark (PaddlePaddle#9749)

* open tensor_fusion for benchmark

fix loraga merge (PaddlePaddle#9765)

* fix loraga merge

* change sign

Fix ernie ci auto trainer error (PaddlePaddle#9758)

* [AutoParallel]:fix ernie auto_trainer error

* Update run_pretrain_auto.py

Update README.md (PaddlePaddle#9766)

* Update README.md

[BugFix] Fix matryoshka norm loss (PaddlePaddle#9774)

* fix matryoshka norm

[Distributed] support fuse optimizer (PaddlePaddle#9519) (PaddlePaddle#9777)

Update register_sequence_parallel_allreduce_hooks (PaddlePaddle#9782)

* fix sequence parallel

* update register_sequence_parallel_allreduce_hooks

* update fuse_sequence_parallel_allreduce

Fix ce error (PaddlePaddle#9783)

* [AutoParallel]:fix ci error

* [AutoParallel]:fix ci error

fix (PaddlePaddle#9779)

[MoE] fix expert parallel (PaddlePaddle#9760)

* fix moe uc

fix dpo pp criterion (PaddlePaddle#9786)

[Infer] Add pir_model path for server infer. (PaddlePaddle#9790)

fix d2s

fix v3 verify

support qk_head_dim != v_head_dim

support fp8 batch gemm on cutlass3.x

upgrade cutlass version for block_wise fp8 gemm

change cutlass commit to ckl117 group_wise branch

support fp8 block gemm, but private cutlass commit, and TODO: update fp8 dual gemm api on cutlass3.x

support auto tune fp8 block gemm code

update cutlass to v3.7.0, todo: support block gemm based on v3.7.0

support block gemm on cutlass v3.7.0 commit

code check

code check

check dynamic_quant

add block builder dir

rename group_quant

fix wint8 v_head_dim

fix rope

fix qwen2

mla use position_ids only

remove control flow

remove gpu concat

fix norm weight dtype

remove all_reduce in fused_moe

partial fp8 support

check group_quant and fake fp8

check

support block gemm

[LLM] support flash device on static model (PaddlePaddle#9619) (PaddlePaddle#9787)

* [LLM] support flash device on static model

* [LLM] adapt pdc sdk

[LLM Benchmark]update scripts (PaddlePaddle#9722)

* add no_proxy & del paddlenlp_ops

* update timeout for dpo

* fix sequence_parallel

* add timeout

* add Total_Tokens_per_second_per_gpu

* fix Tokens_per_second_per_gpu

* update Total_Tokens_per_second_per_gpu

mergekit gpu 1226 (PaddlePaddle#9702)

* mergekit gpu 1226

* merge model gpu

* merge gpu

* add lora model

* change valueerror

* add lora

* gpu test

[LLM] merge code from fastdeploy (PaddlePaddle#9791)

* [LLM] update llm server dockerfiles

* merge code from fastdeploy

[Inference] Support eagle for llama (PaddlePaddle#9812)

[CI] Fix ci of small models (PaddlePaddle#9633)

[Trainer] Wrap model when lora is ON and only do evaluation. (PaddlePaddle#9803)

[README] Update README.md for documention (PaddlePaddle#9785)

* Update README.md

* Update README.md

* Update README_en.md

fix static run

wint8 and fake-fp8; todo: support mismatched data types

support fp8, but ffn1 and moe in wint8

support ffn1 fp8 block gemm

done ffn1 fp8 block gemm

block gemm done

block gemm support batch

refine rope code

compute position_ids use custom op

fix split_param (PaddlePaddle#9817)

[LLM] Update model convert and fix TP for deepseekv3 (PaddlePaddle#9797)

* fix model convert and tp in MoEMLP

* fix tp_action filter

* update convert accoding to num_nextn_predict_layers

* add deepseek-R1

fuse rope

fix macro

fix mixtral

set_state_dict block_wise weight

support fp8 per-tensor network; passing the scale as a Tensor for the tensor-wise gemm is not yet supported

deepseek-v3 fp8 tensor gemm network, but precision is wrong

add triton fp8 fused_moe kernel

fix moe triton kernel

add moe triton kernel

fix

fix fp8 block gemm precision

moe triton fp8 network

support moe triton with correct precision, but shared-expert ffn1/ffn2 incorrect

fp8 block network, no check shared ffn1-ffn2 in v2-lite

delete wint8 in fake

delete some useless code and verify the per-tensor net within qkv, out_linear, ffn1 and ffn2, but the triton moe api doesn't match

fp8 block quant when load model, and code check

fix tokenizer and qwen

[AutoParallel] add sharding tensor_fusion save load switch (PaddlePaddle#9810)

* support tensor_fusion save load

* apply suggestions from code review

Fix handling of abnormal exits in multi-node benchmark tasks (PaddlePaddle#9651)

* Fix handling of abnormal exits in multi-node benchmark tasks

* fix bug

* update

Fix LLAMA arg parsing bug in pp (PaddlePaddle#9806)

[Readme] Update mixtral.md (PaddlePaddle#9829)

[XPU] Support empty_cache on XPUs (PaddlePaddle#9789)

* [XPU] Support empty_cache on XPUs

* warn if current device doesn't support

[Inference] Fix multibatch inference (PaddlePaddle#9831)

* fix batch infra

* fix deepseekv2 infra

Fix position_ids for infra  (PaddlePaddle#9841)

fix moe diff due to e_score_correction_bias

fix fast tokenizer

[LLM] Add pipeline and flashmask for Qwen2Moe and Deepseek (PaddlePaddle#9827)

* add modeling_pp

* add modeling_pp for qwen2moe

* add flashmask and pp for Qwen2MoE and Deepseek

* remove

* fix fast_tokenizer save

* update for topk_weight of noaux_tc

* fix for flashmask

* add use_expert_parallel for pretrain

* fix tokenizer test

[Mergekit]update & add LoRA merge (PaddlePaddle#9811)

* add

* fix bug

* fix

* add

* add lora merge

* add

* add

* add

* add

* add

* add

[Unified Checkpoint] Fix expert parallel (PaddlePaddle#9821)

* fix expert parallel

* fix split_param for expert parallel

* add filter_sync_parameters

fix import

[Inference] Flask server compatible with OpenAI api. (PaddlePaddle#9828)

* flask server compatible with OpenAI api.

* fix max_length to max_tokens.

* fix with think model.

[LLM] fix checkpoint save for non flash mode (PaddlePaddle#9830)

support mla for speculate

[DSK] support deepseek-v3/r1 (mha/fp16/bf16/wint8/wint4) (PaddlePaddle#9769)

* support deepseek-v3

* support head_dim=192,256 for append_attn c16

* update 0113

* attention run

* refine code

* add softmax_scale

* support weight_only_int8

* refine code

* support tp

* delete test_append_attn

* add split fused_moe from ziyuan

* fix repe for deepseek-v3

* add deepseek-v3 class

* fix wint8 precision and refine code

* fix wint4, big diff

* add e_score_correction_bias

* fix head_dim

* fix v3 verify

* fix d2s

* fix v3 verify

* support qk_head_dim != v_head_dim

* fix wint8 v_head_dim

* fix rope

* fix qwen2

* mla use position_ids only

* remove control flow

* remove gpu concat

* fix norm weight dtype

* remove all_reduce in fused_moe

* fix static run

* refine rope code

* compute position_ids use custom op

* fuse rope

* fix macro

* fix mixtral

* support mla for speculate

* fix tokenizer and qwen

* fix moe diff due to e_score_correction_bias

* fix fast tokenizer

* fix import

---------

Co-authored-by: lizhenyun01 <[email protected]>
Co-authored-by: lizhenyun <[email protected]>

Solve the Python-version compatibility problem for type annotations (PaddlePaddle#9853)

mix fp8 and wint8

save extra special tokens (PaddlePaddle#9837)

[Bugfix] Fix dsk rope diff (PaddlePaddle#9859)

* fix dsk diff

* fix

* update

merge develop to check fp8 moe-wint8

fix deepseek v3 fp8 precision

fix deepseek weight quant

[Optimization] Support lower memory cards. (PaddlePaddle#9804)

* support lower memory cards.

* add doc for v100 16G such devices.

* remove debug info.

* add pre-divided factor to overcome overflow problem for fp16 attention.

Support XPU for auto-parallel LLaMa (PaddlePaddle#9796)

* Support XPU for auto-parallel LLaMa

* Update

* Update

* Update

* Update

* Fix CI errors

* Update

[XPU] Add xpu fused op for deepseek (PaddlePaddle#9854)

[Inference] Update deepseek (PaddlePaddle#9864)

* fix

* fix infra

[PreTrain] Support deepseek mfu for pretraining and fix tflops for pretrain pipe model (PaddlePaddle#9855)

* get flops with pp model.

* Support hareware tflops for deepseek.

[Inference]Support mtp with deepseek-v3 (PaddlePaddle#9856)

* support mtp with deepseek_v3 both in static and dygraph mode

* fix speculate tokenizer in unittest

* delete useless code

check code
yuanlehome authored and ckl117 committed Feb 17, 2025
1 parent 2c556e7 commit d982741
Showing 218 changed files with 15,311 additions and 2,556 deletions.
15 changes: 10 additions & 5 deletions README.md
@@ -7,7 +7,7 @@
------------------------------------------------------------------------------------------

<p align="center">
<a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
<a href="https://paddlenlp.readthedocs.io/en/latest/?badge=latest"><img src="https://readthedocs.org/projects/paddlenlp/badge/?version=latest">
<a href="https://github.com/PaddlePaddle/PaddleNLP/releases"><img src="https://img.shields.io/github/v/release/PaddlePaddle/PaddleNLP?color=ffa"></a>
<a href=""><img src="https://img.shields.io/badge/python-3.7+-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/os-linux%2C%20win%2C%20mac-pink.svg"></a>
@@ -16,6 +16,7 @@
<a href="https://pypi.org/project/paddlenlp/"><img src="https://img.shields.io/pypi/dm/paddlenlp?color=9cf"></a>
<a href="https://github.com/PaddlePaddle/PaddleNLP/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/PaddleNLP?color=9cc"></a>
<a href="https://github.com/PaddlePaddle/PaddleNLP/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/PaddleNLP?color=ccf"></a>
<a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
</p>

<h4 align="center">
@@ -69,6 +70,9 @@

The LLM toolkit's high-performance inference module has built-in dynamic insertion and end-to-end operator fusion strategies, greatly accelerating parallel inference. The underlying implementation details are encapsulated, providing out-of-the-box high-performance parallel inference.

## Documentation
For more detailed documentation, please visit the [PaddleNLP Documentation](https://paddlenlp.readthedocs.io/).

------------------------------------------------------------------------------------------

## Model Support
@@ -91,6 +95,7 @@
| [ChatGLM3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm3-6b |
| [DeepSeekV2](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V2, deepseek-ai/DeepSeek-V2-Chat, deepseek-ai/DeepSeek-V2-Lite, deepseek-ai/DeepSeek-V2-Lite-Chat, deepseek-ai/DeepSeek-Coder-V2-Base, deepseek-ai/DeepSeek-Coder-V2-Instruct, deepseek-ai/DeepSeek-Coder-V2-Lite-Base, deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct |
| [DeepSeekV3](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V3, deepseek-ai/DeepSeek-V3-Base |
| [DeepSeek-R1](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-R1, deepseek-ai/DeepSeek-R1-Zero, deepseek-ai/DeepSeek-R1-Distill-Llama-70B, deepseek-ai/DeepSeek-R1-Distill-Llama-8B, deepseek-ai/DeepSeek-R1-Distill-Qwen-14B, deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, deepseek-ai/DeepSeek-R1-Distill-Qwen-32B, deepseek-ai/DeepSeek-R1-Distill-Qwen-7B |
| [Gemma](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/gemma) | google/gemma-7b, google/gemma-7b-it, google/gemma-2b, google/gemma-2b-it |
| [Mistral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mistral) | mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-7B-v0.1 |
| [Mixtral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mixtral) | mistralai/Mixtral-8x7B-Instruct-v0.1 |
@@ -161,7 +166,7 @@
### Environment Dependencies

* python >= 3.8
* paddlepaddle >= 3.0.0b0
* paddlepaddle >= 3.0.0rc0

If you have not installed PaddlePaddle yet, please refer to the [PaddlePaddle official website](https://www.paddlepaddle.org.cn/) to install it.

@@ -206,7 +211,7 @@ wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwe
wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.idx
cd .. # change folder to PaddleNLP/llm
# To use use_fused_rms_norm=true, first install fused_ln from slm/model_zoo/gpt-3/external_ops
python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_pretrain.py ./config/llama/pretrain_argument.json --use_fused_rms_norm false
python -u run_pretrain.py ./config/qwen/pretrain_argument_0p5b.json
```

### LLM SFT Fine-Tuning
@@ -216,7 +221,7 @@ git clone https://github.com/PaddlePaddle/PaddleNLP.git && cd PaddleNLP # 如已
mkdir -p llm/data && cd llm/data
wget https://bj.bcebos.com/paddlenlp/datasets/examples/AdvertiseGen.tar.gz && tar -zxvf AdvertiseGen.tar.gz
cd .. # change folder to PaddleNLP/llm
python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_finetune.py ./config/llama/sft_argument.json
python -u run_finetune.py ./config/qwen/sft_argument_0p5b.json
```

For more steps in the full LLM workflow, see the [PaddlePaddle LLM toolkit](./llm) introduction.
Expand All @@ -231,7 +236,7 @@ dataset = load_dataset("ZHUI/alpaca_demo", split="train")
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT", device="gpu")
trainer = SFTTrainer(
args=training_args,
model="Qwen/Qwen2.5-0.5B",
model="Qwen/Qwen2.5-0.5B-Instruct",
train_dataset=dataset,
)
trainer.train()
8 changes: 6 additions & 2 deletions README_en.md
@@ -7,7 +7,7 @@
------------------------------------------------------------------------------------------

<p align="center">
<a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
<a href="https://paddlenlp.readthedocs.io/en/latest/?badge=latest"><img src="https://readthedocs.org/projects/paddlenlp/badge/?version=latest">
<a href="https://github.com/PaddlePaddle/PaddleNLP/releases"><img src="https://img.shields.io/github/v/release/PaddlePaddle/PaddleNLP?color=ffa"></a>
<a href=""><img src="https://img.shields.io/badge/python-3.7+-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/os-linux%2C%20win%2C%20mac-pink.svg"></a>
@@ -16,6 +16,7 @@
<a href="https://pypi.org/project/paddlenlp/"><img src="https://img.shields.io/pypi/dm/paddlenlp?color=9cf"></a>
<a href="https://github.com/PaddlePaddle/PaddleNLP/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/PaddleNLP?color=9cc"></a>
<a href="https://github.com/PaddlePaddle/PaddleNLP/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/PaddleNLP?color=ccf"></a>
<a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
</p>

<h4 align="center">
@@ -52,6 +53,9 @@ The fine-tuning algorithms are deeply integrated with zero-padding data streams

The high-performance inference module of the large model toolkit incorporates dynamic insertion and operator fusion strategies throughout the entire process, greatly accelerating parallel inference speed. The underlying implementation details are encapsulated, enabling out-of-the-box high-performance parallel inference capabilities.

## Documentation
For detailed documentation, visit the [PaddleNLP Documentation](https://paddlenlp.readthedocs.io/).

------------------------------------------------------------------------------------------

## Support Models
@@ -68,7 +72,7 @@ Detailed list 👉 [Supported Model List](https://github.com/PaddlePaddle/Paddle
### Pip Installation

```shell
pip install --upgrade paddlenlp==3.0.0b2
pip install --upgrade paddlenlp==3.0.0b3
```

or you can install the latest develop branch code with the following command:
7 changes: 5 additions & 2 deletions csrc/README.md
@@ -1,6 +1,9 @@
# PaddleNLP Custom OPs
# PaddleNLP High-Performance Custom Inference Operators for Large Models

This document describes how to build and install the PaddleNLP custom OPs.
This document describes how to build and install PaddleNLP's high-performance custom inference operators for large models.

Using these high-performance operators can substantially speed up large model inference.
For LLM inference tutorials, see [here](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/README.md#6-%E6%8E%A8%E7%90%86)

## Install C++ Dependencies

37 changes: 27 additions & 10 deletions csrc/gpu/append_attention.cu
@@ -56,6 +56,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const int max_input_length,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
@@ -97,21 +98,21 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
paddle::DataType::INT8,
qkv.place());
}
else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
paddle::DataType::FLOAT8_E4M3FN,
qkv.place());
}else{
PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
}
} else {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
D,
qkv.place());
}
@@ -203,6 +204,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
encoder_block_shape_q,
max_input_length,
max_enc_len_this_time_data,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -240,6 +242,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
encoder_block_shape_q,
max_input_length,
max_enc_len_this_time_data,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -282,6 +285,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
encoder_block_shape_q,
max_input_length,
max_enc_len_this_time_data,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -428,6 +432,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
decoder_block_shape_q,
max_input_length,
max_len_kv_data,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -465,6 +470,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
decoder_block_shape_q,
max_input_length,
max_len_kv_data,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -508,6 +514,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
decoder_block_shape_q,
max_input_length,
max_len_kv_data,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -565,6 +572,7 @@ std::vector<paddle::Tensor> AppendAttention(
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const int max_input_length,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
@@ -578,9 +586,10 @@ std::vector<paddle::Tensor> AppendAttention(
meta_data.token_nums = qkv_dims[0];
meta_data.kv_num_heads = key_cache_dims[1];
meta_data.head_dims = key_cache_dims[3];
const int total_num_head =
qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims;
meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads;
meta_data.head_dims_v = value_cache.dims()[3];
const int q_hidden_size =
qkv_dims[qkv_dims.size() - 1] - meta_data.kv_num_heads * (meta_data.head_dims + meta_data.head_dims_v);
meta_data.q_num_heads = q_hidden_size / meta_data.head_dims;

meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = key_cache.dims()[2];
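This hunk replaces the old `total_num_head` derivation so that the query/key head dim and the value head dim may differ (the MLA layouts this commit targets). As a reading aid only, here is a standalone sketch of the same arithmetic with hypothetical dimensions — the 192/128 split is illustrative, not taken from this diff:

```cpp
#include <cassert>
#include <cstdio>

int main() {
  // Hypothetical MLA-style layout where the qk head dim differs from the v head dim.
  const int expected_q_heads = 16;
  const int kv_num_heads = 16;
  const int head_dim_qk = 192;  // e.g. 128 "nope" dims + 64 rope dims
  const int head_dim_v = 128;

  // Fused qkv tensor: [q | k | v] concatenated along the last axis.
  const int qkv_hidden = expected_q_heads * head_dim_qk +
                         kv_num_heads * (head_dim_qk + head_dim_v);

  // Same derivation as above: strip the k/v portion, then divide by the qk head dim.
  const int q_hidden_size = qkv_hidden - kv_num_heads * (head_dim_qk + head_dim_v);
  const int q_num_heads = q_hidden_size / head_dim_qk;
  assert(q_num_heads == expected_q_heads);

  // The attention output width follows the *value* head dim, matching the
  // fmha_out shape change earlier in this diff.
  std::printf("q_num_heads=%d, fmha_out width=%d\n", q_num_heads, q_num_heads * head_dim_v);
  return 0;
}
```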
@@ -626,6 +635,7 @@ std::vector<paddle::Tensor> AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -672,6 +682,7 @@ std::vector<paddle::Tensor> AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -719,6 +730,7 @@ std::vector<paddle::Tensor> AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -764,6 +776,7 @@ std::vector<paddle::Tensor> AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -821,10 +834,12 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape) {
const int token_num = qkv_shape[0];
const int kv_num_heads = key_cache_shape[1];
const int head_dim = key_cache_shape[3];
const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
const int num_heads = total_num_head - 2 * kv_num_heads;
return {{token_num, num_heads * head_dim}, qkv_shape};
const int head_dim_qk = key_cache_shape[3];
const int head_dim_v = value_cache_shape[3];
const int q_hidden_size =
qkv_shape[qkv_shape.size() - 1] - kv_num_heads * (head_dim_qk + head_dim_v);
const int num_heads = q_hidden_size / head_dim_qk;
return {{token_num, num_heads * head_dim_v}, qkv_shape};
}

std::vector<paddle::DataType> AppendAttentionInferDtype(
Expand Down Expand Up @@ -865,6 +880,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const int max_input_length,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
@@ -941,6 +957,7 @@ PD_BUILD_OP(append_attention)
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"max_input_length: int",
"softmax_scale: float",
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",
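The registration above now exposes `softmax_scale` as an op attribute instead of hard-coding the scale in the kernel. A minimal sketch of a plausible caller-side default, assuming conventional scaled dot-product attention — the actual Python-side wiring is not shown in this diff:

```cpp
#include <cmath>

// Hypothetical helper (not part of this commit): the usual default when the
// caller does not override the scale is 1/sqrt(qk head dim). Models such as
// DeepSeek-V3 may pass an adjusted value, e.g. to fold in rope-scaling factors.
inline float DefaultSoftmaxScale(int qk_head_dim) {
  return 1.0f / std::sqrt(static_cast<float>(qk_head_dim));
}
```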