Commit
prefill use flash_attn
yuanlehome committed Feb 24, 2025
1 parent 2fb3378 commit 0e31296
Showing 9 changed files with 75 additions and 174 deletions.
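In brief, the visible hunks remove the mla_use_absorb flag from the append-attention and cache-write kernel plumbing, so KV-cache block ids are always resolved through the per-sequence block table; on the Python side, the DeepSeek-V2 model drops its dedicated prefill_cache_k_buffer / prefill_cache_v_buffer allocation and instead records block_size and passes it through the forward kwargs. The flash_attn-based prefill path named in the commit title is not itself visible in the hunks shown on this page.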
8 changes: 0 additions & 8 deletions csrc/gpu/append_attention.cu
@@ -145,7 +145,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
kv_num_blocks_data,
max_input_length,
use_neox_rotary_style,
- mla_use_absorb,
main_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
@@ -173,7 +172,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
kv_num_blocks_data,
max_input_length,
use_neox_rotary_style,
- mla_use_absorb,
main_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
@@ -215,7 +213,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
causal,
false,
true,
- mla_use_absorb,
main_stream,
&fmha_out);
break;
@@ -254,7 +251,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
causal,
false,
true,
- mla_use_absorb,
main_stream,
&fmha_out);
break;
@@ -298,7 +294,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
causal,
false,
true,
- mla_use_absorb,
main_stream,
&fmha_out);
}
@@ -446,7 +441,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
causal,
!speculate_decoder,
!speculate_decoder,
- mla_use_absorb,
exec_stream,
&fmha_out);
break;
@@ -485,7 +479,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
causal,
!speculate_decoder,
!speculate_decoder,
- mla_use_absorb,
exec_stream,
&fmha_out);
break;
@@ -530,7 +523,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
causal,
!speculate_decoder,
!speculate_decoder,
- mla_use_absorb,
exec_stream,
&fmha_out);
}
16 changes: 5 additions & 11 deletions csrc/gpu/append_attn/append_attention_c16_impl.cuh
@@ -58,8 +58,7 @@ __global__ void multi_query_append_attention_kernel(
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
- const int speculate_max_draft_token_num = 5,
- const bool mla_use_absorb = false) {
+ const int speculate_max_draft_token_num = 5) {
constexpr uint32_t num_vecs_per_head_qk =
HEAD_DIM_QK / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_v = HEAD_DIM_V / num_elems_per_128b<T>();
@@ -222,7 +221,7 @@ __global__ void multi_query_append_attention_kernel(
wid * 4 + tid / 8, tid % 8);

uint32_t kv_idx_base = chunk_start;
- int block_id = mla_use_absorb ? kv_idx_base / BLOCK_SIZE : __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]);
+ int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]);
const uint32_t const_offset_k = kv_head_idx * k_h_stride +
(wid * 4 + tid / 8) * k_b_stride +
tid % 8 * num_elems_per_128b<T>();
@@ -328,7 +327,7 @@ __global__ void multi_query_append_attention_kernel(
__syncthreads();

kv_idx_base += num_frags_z * 16;
- block_id = mla_use_absorb ? kv_idx_base / BLOCK_SIZE : __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]);
+ block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]);
if (block_id < 0) {
block_id = 0;
}
@@ -1024,7 +1023,6 @@ void MultiQueryAppendAttention(
const float in_scale,
const int speculate_max_draft_token_num,
const bool is_decoder,
- const bool mla_use_absorb,
cudaStream_t &stream,
paddle::Tensor *out) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
@@ -1135,8 +1133,7 @@ void MultiQueryAppendAttention(
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
- speculate_max_draft_token_num,
- mla_use_absorb);
+ speculate_max_draft_token_num);

} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
@@ -1194,8 +1191,7 @@ void MultiQueryAppendAttention(
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
- speculate_max_draft_token_num,
- mla_use_absorb);
+ speculate_max_draft_token_num);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
@@ -1553,7 +1549,6 @@ void CascadeAppendAttentionC16Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
- const bool mla_use_absorb,
cudaStream_t &stream,
paddle::Tensor *out) {
const auto token_num = meta_data.token_nums;
@@ -1618,7 +1613,6 @@ void CascadeAppendAttentionC16Kernel(
in_scale,
speculate_max_draft_token_num,
is_decoder,
- mla_use_absorb,
stream,
out);
})})})})})})})
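The changes in this file drop the mla_use_absorb parameter and collapse its branch: where the flag used a token's logical block index directly, the kernels now always go through the per-sequence block table. The snippet below is a minimal, hypothetical C++ sketch of that addressing scheme; the names (block_table, block_size, KvSlot, locate_kv_slot) only mirror or are invented around the kernel code and are not part of this commit.

#include <cstdint>

// Hypothetical sketch of paged KV-cache addressing: a logical token
// position maps to a (physical block id, offset within the block) pair
// through the per-sequence block table. This is the lookup the kernels
// above keep once the mla_use_absorb branch is removed.
struct KvSlot {
  int block_id;      // physical block index in the global KV cache
  int block_offset;  // slot within that block
};

inline KvSlot locate_kv_slot(const int* block_table,  // per-sequence block table
                             int seq_pos,             // logical token position
                             int block_size) {
  int block_id = block_table[seq_pos / block_size];   // table lookup
  if (block_id < 0) block_id = 0;                     // the kernels clamp negative ids
  return {block_id, seq_pos % block_size};
}

The removed branch instead treated seq_pos / block_size itself as the block id, which only holds when a sequence's blocks are laid out contiguously, as they were in the dedicated prefill buffers this commit also deletes from modeling.py.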
3 changes: 0 additions & 3 deletions csrc/gpu/append_attn/append_attention_kernel.h
@@ -57,7 +57,6 @@ void CascadeAppendAttentionC16Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
- const bool mla_use_absorb,
cudaStream_t& stream,
paddle::Tensor* out);

@@ -191,7 +190,6 @@ void CascadeAppendAttentionKernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
- const bool mla_use_absorb,
cudaStream_t& stream,
paddle::Tensor* out) {
if (cache_quant_type_str == "none") {
@@ -226,7 +224,6 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
- mla_use_absorb,
stream,
out);
} else if (cache_quant_type_str == "cache_int8") {
9 changes: 3 additions & 6 deletions csrc/gpu/append_attn/encoder_write_cache_with_rope_impl.cuh
@@ -636,8 +636,7 @@ __global__ void cache_kernel(
const int head_size_v,
const int block_size,
const uint32_t elem_cnt,
- const int kv_num_heads,
- const bool mla_use_absorb = false) {
+ const int kv_num_heads) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;

@@ -661,7 +660,7 @@

block_table_now = block_tables + ori_bi * max_blocks_per_seq;

- const uint32_t block_idx = mla_use_absorb ? ori_seq_id / block_size : block_table_now[ori_seq_id / block_size];
+ const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
const uint32_t block_offset = ori_seq_id % block_size;

if (bias < hidden_size_k) {
@@ -1467,7 +1466,6 @@ void CascadeAppendWriteCacheKVQKV(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const int max_seq_len,
- const bool mla_use_absorb,
cudaStream_t &stream,
paddle::Tensor *key_cache_out,
paddle::Tensor *value_cache_out) {
@@ -1501,8 +1499,7 @@ void CascadeAppendWriteCacheKVQKV(
head_dim_v,
block_size,
elem_nums,
- kv_num_heads,
- mla_use_absorb);
+ kv_num_heads);
}

template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>
2 changes: 0 additions & 2 deletions csrc/gpu/append_attn/encoder_write_cache_with_rope_kernel.h
@@ -40,7 +40,6 @@ void EncoderWriteCacheWithRopeKernel(
const int num_blocks,
const int max_seq_len,
const bool use_neox_style,
- const bool mla_use_absorb,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
@@ -97,7 +96,6 @@ void EncoderWriteCacheWithRopeKernel(
seq_lens_encoder,
seq_lens_decoder,
max_seq_len,
- mla_use_absorb,
stream,
key_cache_out,
value_cache_out);
@@ -54,6 +54,5 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
const bool causal,
const bool is_decoder,
const bool enable_prefill,
- const bool mla_use_absorb,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -38,7 +38,6 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
const int num_blocks,
const int max_seq_len,
const bool use_neox_style,
- const bool mla_use_absorb,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
27 changes: 3 additions & 24 deletions paddlenlp/experimental/transformers/deepseek_v2/modeling.py
@@ -174,6 +174,7 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
self.config = config

self.max_seq_len = config.max_seq_len
+ self.block_size = config.block_size

self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size
@@ -483,32 +484,8 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
for idx in range(self.num_layers)
]

- self.prefill_cache_k_buffer: paddle.Tensor = None
- self.prefill_cache_v_buffer: paddle.Tensor = None
- if self.config.mla_use_matrix_absorption:
- max_batch_size = 1
- max_block_nums = max_batch_size * (self.max_seq_len + config.block_size - 1) // config.block_size
- cache_k_shape = [
- max_block_nums,
- config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
- config.block_size,
- config.qk_nope_head_dim + config.qk_rope_head_dim,
- ]
- cache_v_shape = [
- max_block_nums,
- config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
- config.block_size,
- config.v_head_dim,
- ]
- self.prefill_cache_k_buffer = paddle.empty(shape=cache_k_shape, dtype=paddle.get_default_dtype())
- self.prefill_cache_v_buffer = paddle.empty(shape=cache_v_shape, dtype=paddle.get_default_dtype())
- self.register_buffer("prefill_cache_k_buffer", self.prefill_cache_k_buffer, persistable=False)
- self.register_buffer("prefill_cache_v_buffer", self.prefill_cache_v_buffer, persistable=False)

mla_config = MLAConfig(
use_matrix_absorption=self.config.mla_use_matrix_absorption,
- prefill_cache_k_buffer=self.prefill_cache_k_buffer,
- prefill_cache_v_buffer=self.prefill_cache_v_buffer,
q_lora_rank=self.config.q_lora_rank,
kv_lora_rank=self.config.kv_lora_rank,
qk_nope_head_dim=self.config.qk_nope_head_dim,
@@ -943,6 +920,7 @@ def forward(
kwargs["cu_seqlens_k"] = cu_seqlens_k
kwargs["padding_offsets"] = padding_offset
kwargs["max_input_length"] = self.max_seq_len
kwargs["block_size"] = self.block_size

inputs_embeds = self.embed_tokens(ids_remove_padding)

@@ -1010,6 +988,7 @@ def forward(
kwargs["cu_seqlens_k"] = cu_seqlens_k
kwargs["padding_offsets"] = padding_offset
kwargs["max_input_length"] = self.max_seq_len
kwargs["block_size"] = self.block_size

inputs_embeds = self.embed_tokens(ids_remove_padding)
inputs_embeds = paddle.concat([self.enorm(inputs_embeds), self.hnorm(pre_hidden_states)], axis=-1)
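The large block removed from the first modeling.py hunk above sized temporary prefill K/V buffers from the paged-cache geometry: the block count is a ceiling division of max_seq_len by block_size, and each block holds block_size slots per KV head, with a K head dim of qk_nope_head_dim + qk_rope_head_dim and a V head dim of v_head_dim. Below is a small C++ sketch of that shape arithmetic; the constants are illustrative stand-ins, not values taken from this commit (the removed Python read them from DeepseekV2Config).

#include <cstdio>

int main() {
  // Illustrative values only; the removed Python read these from the config.
  const int max_batch_size    = 1;
  const int max_seq_len       = 8192;
  const int block_size        = 64;
  const int kv_heads_per_rank = 1;        // num_key_value_heads / tensor_parallel_degree
  const int qk_head_dim       = 128 + 64; // qk_nope_head_dim + qk_rope_head_dim
  const int v_head_dim        = 128;

  // Ceiling division, as in the removed buffer-shape code.
  const int max_block_nums =
      max_batch_size * ((max_seq_len + block_size - 1) / block_size);

  // Element counts for buffers shaped
  //   K: [max_block_nums, kv_heads_per_rank, block_size, qk_head_dim]
  //   V: [max_block_nums, kv_heads_per_rank, block_size, v_head_dim]
  const long long k_elems = 1LL * max_block_nums * kv_heads_per_rank * block_size * qk_head_dim;
  const long long v_elems = 1LL * max_block_nums * kv_heads_per_rank * block_size * v_head_dim;

  std::printf("blocks=%d  K elems=%lld  V elems=%lld\n", max_block_nums, k_elems, v_elems);
  return 0;
}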
(The diff for the ninth changed file did not load on this page.)