
Commit f07e997

[None][feat] Use triton kernels for RocketKV prediction module (#8682)
Signed-off-by: yuhangh <[email protected]>
1 parent cc4c980 commit f07e997

17 files changed: +3532 −685 lines changed


cpp/tensorrt_llm/kernels/sparseAttentionKernels.cu

Lines changed: 108 additions & 38 deletions
@@ -20,7 +20,7 @@ namespace tensorrt_llm
 {
 namespace kernels
 {
-template <int THREADS_PER_BLOCK>
+template <int THREADS_PER_BLOCK, int MAX_NUM_PAGES>
 __global__ void gatherKvPageOffsetsKernel(
     int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]
     int32_t* output_seq_lengths,     // [num_head_kv, batch_size]
@@ -32,23 +32,33 @@ __global__ void gatherKvPageOffsetsKernel(
     // Each CUDA block processes one sequence from the batch for one head.
     int32_t const head_idx = blockIdx.x;
     int32_t const batch_idx = blockIdx.y;
+    int32_t const indices_block_size = sparse_params.sparse_attn_indices_block_size;
     if (batch_idx >= batch_size)
     {
         return;
     }

-    // Shared memory for reduction.
-    __shared__ typename cub::BlockReduce<Pair, THREADS_PER_BLOCK>::TempStorage temp_storage;
+    using BlockScan = cub::BlockScan<int32_t, THREADS_PER_BLOCK>;
+    using BlockReduce = cub::BlockReduce<Pair, THREADS_PER_BLOCK>;
+
+    __shared__ typename BlockScan::TempStorage temp_storage_scan;
+    __shared__ typename BlockReduce::TempStorage temp_storage_reduce;
+
+    __shared__ int32_t s_page_mask[MAX_NUM_PAGES];
+    __shared__ int32_t s_cu_page_mask[MAX_NUM_PAGES];
+    __shared__ int32_t s_scan_total; // Store total count from scan

     // Get the range of sparse indices and the sequence length.
     int32_t const start_offset = sparse_params.sparse_attn_offsets[batch_idx];
     int32_t const end_offset = sparse_params.sparse_attn_offsets[batch_idx + 1];
-    int32_t const total_pages = sparse_params.sparse_attn_offsets[batch_size];
-    int32_t const num_sparse_pages = end_offset - start_offset;
+    int32_t const sparse_attn_indices_stride = sparse_params.sparse_attn_indices_stride;
+    int32_t const num_sparse_indices = end_offset - start_offset;
     int32_t const original_seq_len = seq_lengths[batch_idx];
+    int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
+    int32_t const page_loops = (ori_valid_pages + MAX_NUM_PAGES - 1) / MAX_NUM_PAGES;

     // Get global sparse index.
-    int32_t const sparse_idx_global = head_idx * total_pages + start_offset;
+    int32_t const sparse_idx_global = head_idx * sparse_attn_indices_stride + start_offset;

     // Get the base memory offset. shape: [batch_size, 2, max_num_pages_per_seq]
     size_t const src_base_offset = (size_t) batch_idx * 2 * max_num_pages_per_seq;
@@ -58,56 +68,119 @@ __global__ void gatherKvPageOffsetsKernel(
     int32_t local_max_page_index = -1;
     int32_t local_num_valid_pages = 0;

-    // Perform the gather operation.
-    for (int32_t i = threadIdx.x; i < num_sparse_pages; i += blockDim.x)
+    int32_t src_page_idx_offset = 0;
+    int32_t dst_page_idx_offset = 0;
+    for (int32_t loop_idx = 0; loop_idx < page_loops; loop_idx++)
     {
-        // Get the source idx and offset.
-        int32_t const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global + i];
-        if (src_idx < 0)
+        src_page_idx_offset = loop_idx * MAX_NUM_PAGES;
+        int32_t loop_num_valid_pages = min(MAX_NUM_PAGES, ori_valid_pages - src_page_idx_offset);
+        for (int32_t i = threadIdx.x; i < MAX_NUM_PAGES; i += blockDim.x)
+        {
+            s_page_mask[i] = 0;
+        }
+        __syncthreads();
+
+        for (int32_t i = threadIdx.x; i < num_sparse_indices; i += blockDim.x)
         {
-            continue;
+            int32_t const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global + i];
+            int32_t const src_idx_start = src_idx * indices_block_size;
+            int32_t const src_idx_end = min(src_idx_start + indices_block_size, original_seq_len);
+            for (int32_t j = src_idx_start; j < src_idx_end; j++)
+            {
+                int32_t const src_page_idx = j / tokens_per_page;
+                if (src_page_idx >= src_page_idx_offset && src_page_idx < src_page_idx_offset + loop_num_valid_pages)
+                {
+                    atomicExch(&s_page_mask[src_page_idx - src_page_idx_offset], 1);
+                }
+            }
         }
+        __syncthreads();
+
+        // Handle case when loop_num_valid_pages > blockDim.x by processing in chunks
+        int32_t scan_offset = 0;
+        int32_t const scan_chunks = (loop_num_valid_pages + blockDim.x - 1) / blockDim.x;

-        // Update the local max page index.
-        local_max_page_index = max(local_max_page_index, src_idx);
-        local_num_valid_pages++;
+        for (int32_t chunk_idx = 0; chunk_idx < scan_chunks; chunk_idx++)
+        {
+            int32_t const chunk_start = chunk_idx * blockDim.x;
+            int32_t const chunk_size = min((int32_t) blockDim.x, loop_num_valid_pages - chunk_start);
+
+            int32_t thread_data = (threadIdx.x < chunk_size) ? s_page_mask[chunk_start + threadIdx.x] : 0;
+            int32_t thread_output;
+            int32_t aggregate;
+
+            BlockScan(temp_storage_scan).ExclusiveSum(thread_data, thread_output, aggregate);
+            __syncthreads();
+
+            if (threadIdx.x < chunk_size)
+            {
+                s_cu_page_mask[chunk_start + threadIdx.x] = thread_output + scan_offset;
+            }
+            __syncthreads();

-        // Get the source and destination offsets.
-        size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx;
-        size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx;
-        size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + i;
-        size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + i;
+            // Update scan offset for next chunk
+            scan_offset += aggregate;
+        }

-        // Perform the gather operation: read from the sparse location and write to the dense location.
-        output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0];
-        output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1];
+        if (threadIdx.x == 0)
+        {
+            s_scan_total = scan_offset;
+        }
+        __syncthreads();
+
+        // Perform the gather operation.
+        for (int32_t i = threadIdx.x; i < loop_num_valid_pages; i += blockDim.x)
+        {
+            // Skip if the page is not valid.
+            if (s_page_mask[i] == 0)
+            {
+                continue;
+            }
+
+            int32_t const src_idx = src_page_idx_offset + i;
+            int32_t const dst_idx = dst_page_idx_offset + s_cu_page_mask[i];
+
+            local_max_page_index = max(local_max_page_index, src_idx);
+            local_num_valid_pages++;
+
+            size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx;
+            size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx;
+            size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + dst_idx;
+            size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + dst_idx;
+
+            output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0];
+            output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1];
+        }
+        __syncthreads();
+
+        // Update dst offset using the total count from scan
+        dst_page_idx_offset += s_scan_total;
     }

     // Reduce the local max page indices and number of valid pages.
     Pair local_pair = {local_max_page_index, local_num_valid_pages};
-    Pair result = cub::BlockReduce<Pair, THREADS_PER_BLOCK>(temp_storage).Reduce(local_pair, PairReduceOp());
+    Pair result = BlockReduce(temp_storage_reduce).Reduce(local_pair, PairReduceOp());

     // Update sequence length for this head and batch.
     if (threadIdx.x == 0)
     {
         int32_t const max_page_index = result.max_val;
         int32_t const num_valid_pages = result.sum_val;
-        int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
         size_t const seq_len_offset = (size_t) head_idx * batch_size + batch_idx;
+        int32_t seq_len = 0;
         if (num_valid_pages > 0)
         {
-            int32_t seq_len = original_seq_len - (ori_valid_pages - num_valid_pages) * tokens_per_page;
-            int32_t seq_len_remain = original_seq_len % tokens_per_page;
-            if (max_page_index != ori_valid_pages - 1 && seq_len_remain != 0)
+            if (max_page_index == ori_valid_pages - 1)
             {
-                seq_len += tokens_per_page - seq_len_remain;
+                seq_len = (num_valid_pages - 1) * tokens_per_page
+                    + (original_seq_len - (ori_valid_pages - 1) * tokens_per_page);
+            }
+            else
+            {
+                seq_len = num_valid_pages * tokens_per_page;
             }
-            output_seq_lengths[seq_len_offset] = seq_len;
-        }
-        else
-        {
-            output_seq_lengths[seq_len_offset] = 0;
         }
+        output_seq_lengths[seq_len_offset] = seq_len;
     }
 }

@@ -121,11 +194,8 @@ void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, int32_t* output_
     dim3 grid(num_head_kv, batch_size, 1);
     // The block.
     dim3 block(256, 1, 1);
-    // Shared memory size.
-    size_t smem_size = sizeof(Pair) * 256;

-    // Launch the kernel.
-    gatherKvPageOffsetsKernel<256><<<grid, block, smem_size, stream>>>(output_kv_page_offsets, output_seq_lengths,
+    gatherKvPageOffsetsKernel<256, 512><<<grid, block, 0, stream>>>(output_kv_page_offsets, output_seq_lengths,
         kv_page_offsets, seq_lengths, sparse_params, batch_size, tokens_per_page, max_num_pages_per_seq);
 }
 } // namespace kernels
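
The reworked kernel replaces the old direct per-page gather with a mask/scan/gather scheme: sparse indices now arrive at a block granularity of `sparse_attn_indices_block_size` tokens, every KV page touched by a selected block is flagged in a shared-memory mask, a CUB exclusive prefix sum turns the flags into dense destination slots, and the gathered sequence length is recomputed so that only a partially filled final page shortens it. The sketch below is a single-threaded C++ reference of that logic, for illustration only; the function and struct names are invented for the example, and it assumes `original_seq_len > 0`. The actual CUDA kernel does the same work per (head, sequence) block with `BlockScan`, shared-memory masks, and the `MAX_NUM_PAGES` loop shown above.

// Single-threaded reference of the mask + exclusive-scan gather (illustrative, not the CUDA path).
#include <algorithm>
#include <cstdint>
#include <vector>

struct GatherResult
{
    std::vector<int32_t> dst_to_src_page; // dense destination slot -> original page index
    int32_t gathered_seq_len = 0;
};

GatherResult gatherPagesReference(std::vector<int32_t> const& sparse_attn_indices, // block-granular indices
    int32_t indices_block_size, int32_t tokens_per_page, int32_t original_seq_len)
{
    int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;

    // 1. Mark every KV page touched by any selected index block.
    std::vector<int32_t> page_mask(ori_valid_pages, 0);
    for (int32_t idx : sparse_attn_indices)
    {
        int32_t const tok_start = idx * indices_block_size;
        int32_t const tok_end = std::min(tok_start + indices_block_size, original_seq_len);
        for (int32_t t = tok_start; t < tok_end; ++t)
        {
            page_mask[t / tokens_per_page] = 1;
        }
    }

    // 2. Exclusive prefix sum of the mask gives each kept page its dense destination slot.
    std::vector<int32_t> cu_page_mask(ori_valid_pages, 0);
    for (int32_t i = 1; i < ori_valid_pages; ++i)
    {
        cu_page_mask[i] = cu_page_mask[i - 1] + page_mask[i - 1];
    }

    // 3. Gather pages and recompute the sequence length seen by the sparse attention kernel.
    GatherResult out;
    int32_t max_page_index = -1;
    int32_t num_valid_pages = 0;
    out.dst_to_src_page.resize(cu_page_mask.back() + page_mask.back());
    for (int32_t i = 0; i < ori_valid_pages; ++i)
    {
        if (page_mask[i] == 0)
        {
            continue;
        }
        out.dst_to_src_page[cu_page_mask[i]] = i;
        max_page_index = std::max(max_page_index, i);
        ++num_valid_pages;
    }
    if (num_valid_pages > 0)
    {
        // Only the last original page may be partially filled, matching the new seq_len formula.
        out.gathered_seq_len = (max_page_index == ori_valid_pages - 1)
            ? (num_valid_pages - 1) * tokens_per_page + (original_seq_len - (ori_valid_pages - 1) * tokens_per_page)
            : num_valid_pages * tokens_per_page;
    }
    return out;
}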

cpp/tensorrt_llm/kernels/sparseAttentionKernels.h

Lines changed: 6 additions & 1 deletion
@@ -35,6 +35,9 @@ struct SparseAttentionParams
     int32_t sparse_mla_topk{0};              // for DSA attention
     void* sparse_mla_kv_cache_pool{nullptr}; // for DSA attention

+    int32_t sparse_attn_indices_block_size{1};
+    int32_t sparse_attn_indices_stride{0};
+
     std::string toString() const
     {
         std::stringstream ss;
@@ -43,7 +46,9 @@ struct SparseAttentionParams
            << "sparse_kv_offsets: " << this->sparse_kv_offsets << std::endl
            << "sparse_attn_offsets: " << this->sparse_attn_offsets << std::endl
            << "sparse_mla_topk: " << this->sparse_mla_topk << std::endl
-           << "sparse_mla_kv_cache_pool: " << this->sparse_mla_kv_cache_pool << std::endl;
+           << "sparse_mla_kv_cache_pool: " << this->sparse_mla_kv_cache_pool << std::endl
+           << "sparse_attn_indices_block_size: " << this->sparse_attn_indices_block_size << std::endl
+           << "sparse_attn_indices_stride: " << this->sparse_attn_indices_stride << std::endl;
        return ss.str();
     }
 };
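
The two new fields describe how the flattened sparse index buffer is laid out: `sparse_attn_indices_block_size` is the number of tokens covered by each sparse index (the default of 1 keeps the old per-token behaviour), and `sparse_attn_indices_stride` is the per-head stride into the indices tensor (set from its last dimension in attentionOp.cpp further down). The snippet below is a minimal sketch of how the kernel consumes them; the helper names are hypothetical and exist only for illustration.

// Illustrative helpers mirroring the kernel's addressing of the new fields (not library API).
#include <algorithm>
#include <cstdint>
#include <utility>

// First flattened position of this head's indices for a given sequence.
inline int32_t indexRangeForHead(int32_t head_idx, int32_t sparse_attn_indices_stride, int32_t start_offset)
{
    return head_idx * sparse_attn_indices_stride + start_offset;
}

// Token span [first, last) selected by one block-granular sparse index.
inline std::pair<int32_t, int32_t> tokenRangeForIndex(
    int32_t src_idx, int32_t sparse_attn_indices_block_size, int32_t original_seq_len)
{
    int32_t const first = src_idx * sparse_attn_indices_block_size;
    int32_t const last = std::min(first + sparse_attn_indices_block_size, original_seq_len);
    return {first, last};
}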

cpp/tensorrt_llm/nanobind/thop/bindings.cpp

Lines changed: 6 additions & 5 deletions
@@ -64,10 +64,11 @@ void initBindings(nb::module_& m)
         nb::arg("softmax_stats_tensor") = std::nullopt, nb::arg("spec_decoding_bool_params"),
         nb::arg("spec_decoding_tensor_params"), nb::arg("sparse_kv_indices") = std::nullopt,
         nb::arg("sparse_kv_offsets") = std::nullopt, nb::arg("sparse_attn_indices") = std::nullopt,
-        nb::arg("sparse_attn_offsets") = std::nullopt, nb::arg("sparse_mla_topk") = std::nullopt,
-        nb::arg("cu_q_seqlens") = std::nullopt, nb::arg("cu_kv_seqlens") = std::nullopt,
-        nb::arg("fmha_scheduler_counter") = std::nullopt, nb::arg("mla_bmm1_scale") = std::nullopt,
-        nb::arg("mla_bmm2_scale") = std::nullopt, nb::arg("quant_q_buffer") = std::nullopt,
-        "Multi-head attention operation", nb::call_guard<nb::gil_scoped_release>());
+        nb::arg("sparse_attn_offsets") = std::nullopt, nb::arg("sparse_attn_indices_block_size"),
+        nb::arg("sparse_mla_topk") = std::nullopt, nb::arg("cu_q_seqlens") = std::nullopt,
+        nb::arg("cu_kv_seqlens") = std::nullopt, nb::arg("fmha_scheduler_counter") = std::nullopt,
+        nb::arg("mla_bmm1_scale") = std::nullopt, nb::arg("mla_bmm2_scale") = std::nullopt,
+        nb::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
+        nb::call_guard<nb::gil_scoped_release>());
 }
 } // namespace tensorrt_llm::nanobind::thop

cpp/tensorrt_llm/pybind/thop/bindings.cpp

Lines changed: 6 additions & 5 deletions
@@ -64,10 +64,11 @@ void initBindings(pybind11::module_& m)
         py::arg("softmax_stats_tensor") = std::nullopt, py::arg("spec_decoding_bool_params"),
         py::arg("spec_decoding_tensor_params"), py::arg("sparse_kv_indices") = std::nullopt,
         py::arg("sparse_kv_offsets") = std::nullopt, py::arg("sparse_attn_indices") = std::nullopt,
-        py::arg("sparse_attn_offsets") = std::nullopt, py::arg("sparse_mla_topk") = std::nullopt,
-        py::arg("cu_q_seqlens") = std::nullopt, py::arg("cu_kv_seqlens") = std::nullopt,
-        py::arg("fmha_scheduler_counter") = std::nullopt, py::arg("mla_bmm1_scale") = std::nullopt,
-        py::arg("mla_bmm2_scale") = std::nullopt, py::arg("quant_q_buffer") = std::nullopt,
-        "Multi-head attention operation", py::call_guard<py::gil_scoped_release>());
+        py::arg("sparse_attn_offsets") = std::nullopt, py::arg("sparse_attn_indices_block_size"),
+        py::arg("sparse_mla_topk") = std::nullopt, py::arg("cu_q_seqlens") = std::nullopt,
+        py::arg("cu_kv_seqlens") = std::nullopt, py::arg("fmha_scheduler_counter") = std::nullopt,
+        py::arg("mla_bmm1_scale") = std::nullopt, py::arg("mla_bmm2_scale") = std::nullopt,
+        py::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
+        py::call_guard<py::gil_scoped_release>());
 }
 } // namespace tensorrt_llm::pybind::thop

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 21 additions & 16 deletions
@@ -86,10 +86,11 @@ class RunnerBase
         c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
         torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
         torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
-        torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk,
-        std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
-        std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
-        std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer) const
+        torch::optional<torch::Tensor> sparse_attn_offsets, int64_t const sparse_attn_indices_block_size,
+        int32_t const sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
+        std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
+        std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
+        std::optional<torch::Tensor> quant_q_buffer) const
         = 0;
 };

@@ -146,10 +147,11 @@ class Runner : public RunnerBase
         c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
         torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
         torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
-        torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk,
-        std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
-        std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
-        std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer) const override
+        torch::optional<torch::Tensor> sparse_attn_offsets, int64_t const sparse_attn_indices_block_size,
+        int32_t const sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
+        std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
+        std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
+        std::optional<torch::Tensor> quant_q_buffer) const override
     {
         auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
         T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
@@ -395,6 +397,9 @@ class Runner : public RunnerBase
             = sparse_attn_indices.has_value() ? sparse_attn_indices.value().data_ptr<int32_t>() : nullptr;
         op.mRuntimeSparseAttentionParams.sparse_attn_offsets
             = sparse_attn_offsets.has_value() ? sparse_attn_offsets.value().data_ptr<int32_t>() : nullptr;
+        op.mRuntimeSparseAttentionParams.sparse_attn_indices_block_size = sparse_attn_indices_block_size;
+        op.mRuntimeSparseAttentionParams.sparse_attn_indices_stride
+            = sparse_attn_indices.has_value() ? sparse_attn_indices.value().size(-1) : 0;
         if (op.isMLAEnabled() && op.mUseSparseAttention)
         {
             op.mRuntimeSparseAttentionParams.sparse_mla_topk = sparse_mla_topk;
@@ -589,10 +594,10 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
     std::optional<torch::Tensor> sparse_kv_indices, std::optional<torch::Tensor> sparse_kv_offsets,
     std::optional<torch::Tensor> sparse_attn_indices, std::optional<torch::Tensor> sparse_attn_offsets,
-    std::optional<int64_t> sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
-    std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
-    std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
-    std::optional<torch::Tensor> quant_q_buffer)
+    int64_t const sparse_attn_indices_block_size, std::optional<int64_t> sparse_mla_topk,
+    std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
+    std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
+    std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer)
 {
     TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
     // Use these tensors to infer if the attention is using KV cache
@@ -847,8 +852,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
             rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
             mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
             attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
-            sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale,
-            quant_q_buffer);
+            sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
+            mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
     }

     if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
@@ -866,8 +871,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
             rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
             mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
             attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
-            sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale,
-            quant_q_buffer);
+            sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
+            mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
     }

     TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);
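
On the op side, the attention entry point now takes `sparse_attn_indices_block_size` as a required integer argument, forwards it into the runtime sparse attention params, and derives `sparse_attn_indices_stride` from the last dimension of the optional `sparse_attn_indices` tensor. The sketch below isolates that plumbing; `populateSparseIndexParams` is a hypothetical free function written only for illustration, and it assumes the struct lives in `tensorrt_llm::kernels` as in the kernel sources above.

// Hedged sketch of the parameter plumbing added in this file (not the actual member function).
#include <cstdint>
#include <optional>
#include <torch/torch.h>

#include "tensorrt_llm/kernels/sparseAttentionKernels.h"

void populateSparseIndexParams(tensorrt_llm::kernels::SparseAttentionParams& params,
    std::optional<torch::Tensor> const& sparse_attn_indices, int64_t sparse_attn_indices_block_size)
{
    // Block size is how many tokens each sparse index covers.
    params.sparse_attn_indices_block_size = static_cast<int32_t>(sparse_attn_indices_block_size);
    // Stride is the (possibly padded) last dimension of the per-head indices tensor, or 0 when
    // sparse attention indices are not provided.
    params.sparse_attn_indices_stride
        = sparse_attn_indices.has_value() ? static_cast<int32_t>(sparse_attn_indices->size(-1)) : 0;
}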
