[DSK] Implement mla use matrix-absorption #9875

Open

wants to merge 39 commits into base: develop

Changes from all commits (39 commits)
7ec7f02
mla python part
yuanlehome Feb 17, 2025
1e05b16
absorb mla optimizer
lizhenyun01 Feb 17, 2025
3e1fbe5
Merge pull request #13 from lizhenyun01/deepseek_mla
yuanlehome Feb 17, 2025
22e4e3d
fix c8/c4 dtype in mla
lizhenyun01 Feb 17, 2025
90a158a
Merge pull request #14 from lizhenyun01/deepseek_mla
yuanlehome Feb 17, 2025
1bd78d0
add weight_only part 1
yuanlehome Feb 17, 2025
0aab4a5
dy can run
yuanlehome Feb 17, 2025
1bb8ef8
static can run
yuanlehome Feb 17, 2025
2d68bea
nothing
yuanlehome Feb 17, 2025
ebf1a76
refine network
yuanlehome Feb 18, 2025
871f2e6
fix write cache
lizhenyun01 Feb 18, 2025
8a1e982
Merge pull request #15 from lizhenyun01/deepseek_mla
yuanlehome Feb 18, 2025
549f709
update
yuanlehome Feb 18, 2025
1165609
Merge branch 'deepseek-v3-mla' of https://github.com/yuanlehome/Paddl…
yuanlehome Feb 18, 2025
7dc8c55
add pd_throw
yuanlehome Feb 18, 2025
bbd0051
add pd_throw
yuanlehome Feb 18, 2025
eabd751
fix mla_atn
lizhenyun01 Feb 18, 2025
820bb38
Merge branch 'deepseek-v3-mla' into deepseek_mla
yuanlehome Feb 18, 2025
71e4a7d
Merge pull request #16 from lizhenyun01/deepseek_mla
yuanlehome Feb 18, 2025
d0f40f6
fix mla
lizhenyun01 Feb 19, 2025
657d67d
Merge pull request #17 from lizhenyun01/deepseek_mla
yuanlehome Feb 19, 2025
58a020b
update network
yuanlehome Feb 19, 2025
d060c98
weight only support group wise
yuanlehome Feb 19, 2025
a11bb32
fix MLA
Feb 20, 2025
1ecbb23
Merge pull request #18 from lizhenyun01/deepseek-v3-mla
yuanlehome Feb 20, 2025
2a96da0
update split kv_b
yuanlehome Feb 20, 2025
5e44eeb
fix
yuanlehome Feb 20, 2025
4ed7180
refine if
yuanlehome Feb 21, 2025
184765e
half support new absorb
yuanlehome Feb 21, 2025
9e2ea0e
weight only support new absorb
yuanlehome Feb 21, 2025
4d90d61
fix
yuanlehome Feb 21, 2025
4f1d25c
fix bf16
yuanlehome Feb 22, 2025
8673095
optimize mla
lizhenyun01 Feb 22, 2025
680ed55
Merge pull request #19 from lizhenyun01/deepseek-v3-mla
yuanlehome Feb 22, 2025
5fcaf18
set kv_cache's bsz=1
lizhenyun01 Feb 22, 2025
b425f74
Merge pull request #20 from lizhenyun01/deepseek-v3-mla
yuanlehome Feb 22, 2025
d824c2a
delete max_batch_size
yuanlehome Feb 22, 2025
3102788
refine if
yuanlehome Feb 24, 2025
2fb3378
not_need_stop to cpu
yuanlehome Feb 24, 2025
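
For context on the title: "matrix absorption" in MLA folds the key up-projection W_UK into the query and defers the value up-projection W_UV until after the attention-weighted sum, so attention can run directly over the compressed latent KV cache instead of the materialized per-head keys and values. Below is a minimal NumPy sketch of that identity; all names, shapes, and the omitted 1/sqrt(d) scaling and RoPE terms are illustrative assumptions, not part of this PR's API.

import numpy as np

d_c, d_head, seq = 512, 128, 16                       # latent dim, per-head dim, cached tokens
rng = np.random.default_rng(0)
W_UK = rng.standard_normal((d_c, d_head)) / np.sqrt(d_c)  # key up-projection (one head)
W_UV = rng.standard_normal((d_c, d_head)) / np.sqrt(d_c)  # value up-projection (one head)
c_kv = rng.standard_normal((seq, d_c))                # compressed latent KV cache
q_nope = rng.standard_normal(d_head)                  # non-RoPE part of one query head

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Naive path: materialize per-head keys/values, attend in head space.
k, v = c_kv @ W_UK, c_kv @ W_UV                       # (seq, d_head) each
o_naive = softmax(k @ q_nope) @ v                     # (d_head,)

# Absorbed path: fold W_UK into the query and apply W_UV after the weighted
# sum, so the attention itself runs over the compact latent cache.
q_latent = W_UK @ q_nope                              # (d_c,)
o_absorb = (softmax(c_kv @ q_latent) @ c_kv) @ W_UV   # (d_head,)

assert np.allclose(o_naive, o_absorb)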
34 changes: 25 additions & 9 deletions csrc/gpu/append_attention.cu
@@ -62,7 +62,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
const float out_linear_in_scale,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const bool speculate_decoder,
const bool mla_use_absorb) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -144,6 +145,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
kv_num_blocks_data,
max_input_length,
use_neox_rotary_style,
mla_use_absorb,
main_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
@@ -171,6 +173,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
kv_num_blocks_data,
max_input_length,
use_neox_rotary_style,
mla_use_absorb,
main_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
@@ -212,6 +215,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
causal,
false,
true,
mla_use_absorb,
main_stream,
&fmha_out);
break;
@@ -250,6 +254,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
causal,
false,
true,
mla_use_absorb,
main_stream,
&fmha_out);
break;
@@ -293,12 +298,13 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
causal,
false,
true,
mla_use_absorb,
main_stream,
&fmha_out);
}
}

if (max_dec_len_this_time_data > 0) {
if (!mla_use_absorb && max_dec_len_this_time_data > 0) {
cudaStream_t exec_stream;
if (max_enc_len_this_time_data > 0) {
cudaStreamWaitEvent(decoder_stream, main_event);
@@ -440,6 +446,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
causal,
!speculate_decoder,
!speculate_decoder,
mla_use_absorb,
exec_stream,
&fmha_out);
break;
@@ -478,6 +485,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
causal,
!speculate_decoder,
!speculate_decoder,
mla_use_absorb,
exec_stream,
&fmha_out);
break;
@@ -522,6 +530,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
causal,
!speculate_decoder,
!speculate_decoder,
mla_use_absorb,
exec_stream,
&fmha_out);
}
@@ -578,7 +587,8 @@ std::vector<paddle::Tensor> AppendAttention(
const float out_linear_in_scale,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const bool speculate_decoder,
const bool mla_use_absorb) {
AppendAttnMetaData meta_data;

const auto& qkv_dims = qkv.dims();
@@ -641,7 +651,8 @@ std::vector<paddle::Tensor> AppendAttention(
out_linear_in_scale,
speculate_max_draft_token_num,
causal,
speculate_decoder);
speculate_decoder,
mla_use_absorb);
}
case paddle::DataType::BFLOAT16: {
return AppendAttentionKernel<paddle::DataType::BFLOAT16>(
@@ -688,7 +699,8 @@ std::vector<paddle::Tensor> AppendAttention(
out_linear_in_scale,
speculate_max_draft_token_num,
causal,
speculate_decoder);
speculate_decoder,
mla_use_absorb);
}
case paddle::DataType::INT32: {
if (compute_dtype == "bf16") {
@@ -736,7 +748,8 @@ std::vector<paddle::Tensor> AppendAttention(
out_linear_in_scale,
speculate_max_draft_token_num,
causal,
speculate_decoder);
speculate_decoder,
mla_use_absorb);
} else if (compute_dtype == "fp16") {
return AppendAttentionKernel<paddle::DataType::FLOAT16>(
meta_data,
@@ -782,7 +795,8 @@ std::vector<paddle::Tensor> AppendAttention(
out_linear_in_scale,
speculate_max_draft_token_num,
causal,
speculate_decoder);
speculate_decoder,
mla_use_absorb);
} else {
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
break;
@@ -886,7 +900,8 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const float out_linear_in_scale,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const bool speculate_decoder,
const bool mla_use_absorb) {
if (compute_dtype == "bf16") {
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
@@ -963,7 +978,8 @@ PD_BUILD_OP(append_attention)
"out_linear_in_scale: float",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool"})
"speculate_decoder: bool",
"mla_use_absorb: bool"})
.SetKernelFn(PD_KERNEL(AppendAttention))
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
16 changes: 11 additions & 5 deletions csrc/gpu/append_attn/append_attention_c16_impl.cuh
@@ -58,7 +58,8 @@ __global__ void multi_query_append_attention_kernel(
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5) {
const int speculate_max_draft_token_num = 5,
const bool mla_use_absorb = false) {
constexpr uint32_t num_vecs_per_head_qk =
HEAD_DIM_QK / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_v = HEAD_DIM_V / num_elems_per_128b<T>();
@@ -221,7 +222,7 @@ __global__ void multi_query_append_attention_kernel(
wid * 4 + tid / 8, tid % 8);

uint32_t kv_idx_base = chunk_start;
int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]);
int block_id = mla_use_absorb ? kv_idx_base / BLOCK_SIZE : __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]);
const uint32_t const_offset_k = kv_head_idx * k_h_stride +
(wid * 4 + tid / 8) * k_b_stride +
tid % 8 * num_elems_per_128b<T>();
@@ -327,7 +328,7 @@ __global__ void multi_query_append_attention_kernel(
__syncthreads();

kv_idx_base += num_frags_z * 16;
block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]);
block_id = mla_use_absorb ? kv_idx_base / BLOCK_SIZE : __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]);
if (block_id < 0) {
block_id = 0;
}
@@ -1023,6 +1024,7 @@ void MultiQueryAppendAttention(
const float in_scale,
const int speculate_max_draft_token_num,
const bool is_decoder,
const bool mla_use_absorb,
cudaStream_t &stream,
paddle::Tensor *out) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
@@ -1133,7 +1135,8 @@ void MultiQueryAppendAttention(
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
mla_use_absorb);

} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
@@ -1191,7 +1194,8 @@ void MultiQueryAppendAttention(
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
mla_use_absorb);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
@@ -1549,6 +1553,7 @@ void CascadeAppendAttentionC16Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const bool mla_use_absorb,
cudaStream_t &stream,
paddle::Tensor *out) {
const auto token_num = meta_data.token_nums;
@@ -1613,6 +1618,7 @@ void CascadeAppendAttentionC16Kernel(
in_scale,
speculate_max_draft_token_num,
is_decoder,
mla_use_absorb,
stream,
out);
})})})})})})})
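
In the kernel hunks above, the block-table lookup is bypassed when mla_use_absorb is set: block_id becomes kv_idx_base / BLOCK_SIZE instead of __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]). A plausible reading, consistent with the "set kv_cache's bsz=1" commit, is that the absorbed path stores the latent KV cache in one contiguous buffer, so logical and physical block ids coincide. A hypothetical Python sketch of the two paths (names are illustrative, not this PR's API):

def physical_block_id(kv_idx_base, block_size, block_table_now, mla_use_absorb):
    if mla_use_absorb:
        # Absorbed-MLA path (assumed contiguous, bsz = 1 cache): no indirection needed.
        return kv_idx_base // block_size
    # Default paged path: translate the logical block through the per-sequence block table.
    return block_table_now[kv_idx_base // block_size]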
3 changes: 3 additions & 0 deletions csrc/gpu/append_attn/append_attention_kernel.h
@@ -57,6 +57,7 @@ void CascadeAppendAttentionC16Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const bool mla_use_absorb,
cudaStream_t& stream,
paddle::Tensor* out);

@@ -190,6 +191,7 @@ void CascadeAppendAttentionKernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const bool mla_use_absorb,
cudaStream_t& stream,
paddle::Tensor* out) {
if (cache_quant_type_str == "none") {
@@ -224,6 +226,7 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
mla_use_absorb,
stream,
out);
} else if (cache_quant_type_str == "cache_int8") {