Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix write cache #15

Merged
merged 1 commit into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions csrc/gpu/append_attn/decode_attention_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ __global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [toke
}


template <typename T, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V, uint32_t BLOCK_SIZE, bool CAUSAL, uint32_t NUM_STAGE, CacheType cache_type, uint32_t cache_bytes, uint32_t DEAL_EACH_TIME, uint32_t NUM_THREADS, PosEncMode pos_enc_mode = PosEncMode::kNonePos>
template <typename T, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V, uint32_t BLOCK_SIZE, bool CAUSAL, uint32_t NUM_STAGE, CacheType cache_type, uint32_t cache_bytes, uint32_t DEAL_EACH_TIME, PosEncMode pos_enc_mode = PosEncMode::kNonePos>
void MultiQueryDecoderAttention(
const AppendAttnMetaData& meta_data,
cudaStream_t &stream,
Expand Down Expand Up @@ -872,12 +872,10 @@ void DecodeMLAAttentionKernel(
{DISPATCH_HEAD_DIM(head_dim_qk, HEAD_DIM_QK,
{DISPATCH_HEAD_DIM(head_dim_v, HEAD_DIM_V,
{DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
{DISPATCH_NUM_STAGE(num_stage, NUM_STAGE,
{DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
{DISPATCH_NUM_THREADS(num_threads, NUM_THREADS,
{MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, NUM_STAGE, CacheType::CacheT, 16, DEAL_EACH_TIME, NUM_THREADS>(
{MultiQueryDecoderAttention<T, 128, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, 2, CacheType::CacheT, 16, DEAL_EACH_TIME>(
meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, padding_offsets, cum_offsets,
block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})})})});
block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})});
}

template void DecodeMLAAttentionKernel<paddle::bfloat16>(
Expand Down
2 changes: 2 additions & 0 deletions csrc/gpu/append_attn/mla_cache_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1];
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;

Expand Down Expand Up @@ -191,6 +192,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1];
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;

Expand Down
Loading