@@ -20,7 +20,7 @@ namespace tensorrt_llm
 {
 namespace kernels
 {
-template <int THREADS_PER_BLOCK>
+template <int THREADS_PER_BLOCK, int MAX_NUM_PAGES>
 __global__ void gatherKvPageOffsetsKernel(
     int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]
     int32_t* output_seq_lengths,     // [num_head_kv, batch_size]
@@ -32,23 +32,33 @@ __global__ void gatherKvPageOffsetsKernel(
     // Each CUDA block processes one sequence from the batch for one head.
     int32_t const head_idx = blockIdx.x;
     int32_t const batch_idx = blockIdx.y;
+    int32_t const indices_block_size = sparse_params.sparse_attn_indices_block_size;
     if (batch_idx >= batch_size)
     {
         return;
     }

-    // Shared memory for reduction.
-    __shared__ typename cub::BlockReduce<Pair, THREADS_PER_BLOCK>::TempStorage temp_storage;
+    using BlockScan = cub::BlockScan<int32_t, THREADS_PER_BLOCK>;
+    using BlockReduce = cub::BlockReduce<Pair, THREADS_PER_BLOCK>;
+
+    __shared__ typename BlockScan::TempStorage temp_storage_scan;
+    __shared__ typename BlockReduce::TempStorage temp_storage_reduce;
+
+    __shared__ int32_t s_page_mask[MAX_NUM_PAGES];
+    __shared__ int32_t s_cu_page_mask[MAX_NUM_PAGES];
+    __shared__ int32_t s_scan_total; // Store total count from scan

     // Get the range of sparse indices and the sequence length.
     int32_t const start_offset = sparse_params.sparse_attn_offsets[batch_idx];
     int32_t const end_offset = sparse_params.sparse_attn_offsets[batch_idx + 1];
-    int32_t const total_pages = sparse_params.sparse_attn_offsets[batch_size];
-    int32_t const num_sparse_pages = end_offset - start_offset;
+    int32_t const sparse_attn_indices_stride = sparse_params.sparse_attn_indices_stride;
+    int32_t const num_sparse_indices = end_offset - start_offset;
     int32_t const original_seq_len = seq_lengths[batch_idx];
+    int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
+    int32_t const page_loops = (ori_valid_pages + MAX_NUM_PAGES - 1) / MAX_NUM_PAGES;

     // Get global sparse index.
-    int32_t const sparse_idx_global = head_idx * total_pages + start_offset;
+    int32_t const sparse_idx_global = head_idx * sparse_attn_indices_stride + start_offset;

     // Get the base memory offset. shape: [batch_size, 2, max_num_pages_per_seq]
     size_t const src_base_offset = (size_t) batch_idx * 2 * max_num_pages_per_seq;
@@ -58,56 +68,119 @@ __global__ void gatherKvPageOffsetsKernel(
     int32_t local_max_page_index = -1;
     int32_t local_num_valid_pages = 0;

-    // Perform the gather operation.
-    for (int32_t i = threadIdx.x; i < num_sparse_pages; i += blockDim.x)
+    int32_t src_page_idx_offset = 0;
+    int32_t dst_page_idx_offset = 0;
+    for (int32_t loop_idx = 0; loop_idx < page_loops; loop_idx++)
     {
-        // Get the source idx and offset.
-        int32_t const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global + i];
-        if (src_idx < 0)
+        src_page_idx_offset = loop_idx * MAX_NUM_PAGES;
+        int32_t loop_num_valid_pages = min(MAX_NUM_PAGES, ori_valid_pages - src_page_idx_offset);
+        for (int32_t i = threadIdx.x; i < MAX_NUM_PAGES; i += blockDim.x)
+        {
+            s_page_mask[i] = 0;
+        }
+        __syncthreads();
+
+        for (int32_t i = threadIdx.x; i < num_sparse_indices; i += blockDim.x)
         {
-            continue;
+            int32_t const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global + i];
+            int32_t const src_idx_start = src_idx * indices_block_size;
+            int32_t const src_idx_end = min(src_idx_start + indices_block_size, original_seq_len);
+            for (int32_t j = src_idx_start; j < src_idx_end; j++)
+            {
+                int32_t const src_page_idx = j / tokens_per_page;
+                if (src_page_idx >= src_page_idx_offset && src_page_idx < src_page_idx_offset + loop_num_valid_pages)
+                {
+                    atomicExch(&s_page_mask[src_page_idx - src_page_idx_offset], 1);
+                }
+            }
         }
+        __syncthreads();
+
+        // Handle case when loop_num_valid_pages > blockDim.x by processing in chunks
+        int32_t scan_offset = 0;
+        int32_t const scan_chunks = (loop_num_valid_pages + blockDim.x - 1) / blockDim.x;

-        // Update the local max page index.
-        local_max_page_index = max(local_max_page_index, src_idx);
-        local_num_valid_pages++;
+        for (int32_t chunk_idx = 0; chunk_idx < scan_chunks; chunk_idx++)
+        {
+            int32_t const chunk_start = chunk_idx * blockDim.x;
+            int32_t const chunk_size = min((int32_t) blockDim.x, loop_num_valid_pages - chunk_start);
+
+            int32_t thread_data = (threadIdx.x < chunk_size) ? s_page_mask[chunk_start + threadIdx.x] : 0;
+            int32_t thread_output;
+            int32_t aggregate;
+
+            BlockScan(temp_storage_scan).ExclusiveSum(thread_data, thread_output, aggregate);
+            __syncthreads();
+
+            if (threadIdx.x < chunk_size)
+            {
+                s_cu_page_mask[chunk_start + threadIdx.x] = thread_output + scan_offset;
+            }
+            __syncthreads();

-        // Get the source and destination offsets.
-        size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx;
-        size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx;
-        size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + i;
-        size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + i;
+            // Update scan offset for next chunk
+            scan_offset += aggregate;
+        }

-        // Perform the gather operation: read from the sparse location and write to the dense location.
-        output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0];
-        output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1];
+        if (threadIdx.x == 0)
+        {
+            s_scan_total = scan_offset;
+        }
+        __syncthreads();
+
+        // Perform the gather operation.
+        for (int32_t i = threadIdx.x; i < loop_num_valid_pages; i += blockDim.x)
+        {
+            // Skip if the page is not valid.
+            if (s_page_mask[i] == 0)
+            {
+                continue;
+            }
+
+            int32_t const src_idx = src_page_idx_offset + i;
+            int32_t const dst_idx = dst_page_idx_offset + s_cu_page_mask[i];
+
+            local_max_page_index = max(local_max_page_index, src_idx);
+            local_num_valid_pages++;
+
+            size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx;
+            size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx;
+            size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + dst_idx;
+            size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + dst_idx;
+
+            output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0];
+            output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1];
+        }
+        __syncthreads();
+
+        // Update dst offset using the total count from scan
+        dst_page_idx_offset += s_scan_total;
     }

     // Reduce the local max page indices and number of valid pages.
     Pair local_pair = {local_max_page_index, local_num_valid_pages};
-    Pair result = cub::BlockReduce<Pair, THREADS_PER_BLOCK>(temp_storage).Reduce(local_pair, PairReduceOp());
+    Pair result = BlockReduce(temp_storage_reduce).Reduce(local_pair, PairReduceOp());

     // Update sequence length for this head and batch.
     if (threadIdx.x == 0)
     {
         int32_t const max_page_index = result.max_val;
         int32_t const num_valid_pages = result.sum_val;
-        int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
         size_t const seq_len_offset = (size_t) head_idx * batch_size + batch_idx;
+        int32_t seq_len = 0;
         if (num_valid_pages > 0)
         {
-            int32_t seq_len = original_seq_len - (ori_valid_pages - num_valid_pages) * tokens_per_page;
-            int32_t seq_len_remain = original_seq_len % tokens_per_page;
-            if (max_page_index != ori_valid_pages - 1 && seq_len_remain != 0)
+            if (max_page_index == ori_valid_pages - 1)
             {
-                seq_len += tokens_per_page - seq_len_remain;
+                seq_len = (num_valid_pages - 1) * tokens_per_page
+                    + (original_seq_len - (ori_valid_pages - 1) * tokens_per_page);
+            }
+            else
+            {
+                seq_len = num_valid_pages * tokens_per_page;
             }
-            output_seq_lengths[seq_len_offset] = seq_len;
-        }
-        else
-        {
-            output_seq_lengths[seq_len_offset] = 0;
         }
+        output_seq_lengths[seq_len_offset] = seq_len;
     }
 }

@@ -121,11 +194,8 @@ void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, int32_t* output_
     dim3 grid(num_head_kv, batch_size, 1);
     // The block.
     dim3 block(256, 1, 1);
-    // Shared memory size.
-    size_t smem_size = sizeof(Pair) * 256;

-    // Launch the kernel.
-    gatherKvPageOffsetsKernel<256><<<grid, block, smem_size, stream>>>(output_kv_page_offsets, output_seq_lengths,
+    gatherKvPageOffsetsKernel<256, 512><<<grid, block, 0, stream>>>(output_kv_page_offsets, output_seq_lengths,
         kv_page_offsets, seq_lengths, sparse_params, batch_size, tokens_per_page, max_num_pages_per_seq);
 }
 } // namespace kernels
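The core of the new gather path is a mask-then-compact pattern: each loop iteration marks the KV pages touched by the sparse indices in s_page_mask, runs a chunked cub::BlockScan exclusive sum over the mask to turn it into dense destination slots (s_cu_page_mask), and then copies only the marked pages. The standalone sketch below isolates that pattern; it is not part of the PR, and the kernel name, sizes, and demo main are illustrative assumptions.

// Minimal sketch (assumption, not from the PR): chunked cub::BlockScan compaction of a
// 0/1 page mask into dense destination slots, mirroring s_page_mask -> s_cu_page_mask above.
#include <cub/cub.cuh>
#include <cstdint>
#include <cstdio>

template <int THREADS_PER_BLOCK>
__global__ void compactMaskKernel(
    int32_t const* mask, int32_t num_pages, int32_t* dst_slots, int32_t* num_selected)
{
    using BlockScan = cub::BlockScan<int32_t, THREADS_PER_BLOCK>;
    __shared__ typename BlockScan::TempStorage temp_storage;

    int32_t scan_offset = 0; // pages already selected by earlier chunks
    int32_t const chunks = (num_pages + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
    for (int32_t chunk = 0; chunk < chunks; chunk++)
    {
        int32_t const idx = chunk * THREADS_PER_BLOCK + threadIdx.x;
        int32_t const flag = (idx < num_pages) ? mask[idx] : 0;
        int32_t slot, aggregate;
        BlockScan(temp_storage).ExclusiveSum(flag, slot, aggregate);
        __syncthreads(); // temp_storage is reused by the next chunk
        if (idx < num_pages && flag != 0)
        {
            dst_slots[idx] = scan_offset + slot; // dense position, like s_cu_page_mask[i]
        }
        scan_offset += aggregate;
    }
    if (threadIdx.x == 0)
    {
        *num_selected = scan_offset;
    }
}

int main()
{
    int32_t const num_pages = 10;
    int32_t const h_mask[num_pages] = {0, 1, 1, 0, 0, 1, 0, 0, 1, 0};
    int32_t *d_mask, *d_slots, *d_count;
    cudaMalloc(&d_mask, num_pages * sizeof(int32_t));
    cudaMalloc(&d_slots, num_pages * sizeof(int32_t));
    cudaMalloc(&d_count, sizeof(int32_t));
    cudaMemcpy(d_mask, h_mask, num_pages * sizeof(int32_t), cudaMemcpyHostToDevice);

    // Deliberately small block (4 threads) so the chunked scan runs several iterations.
    compactMaskKernel<4><<<1, 4>>>(d_mask, num_pages, d_slots, d_count);

    int32_t h_count = 0;
    cudaMemcpy(&h_count, d_count, sizeof(int32_t), cudaMemcpyDeviceToHost);
    printf("selected pages: %d\n", h_count); // expected: 4
    cudaFree(d_mask);
    cudaFree(d_slots);
    cudaFree(d_count);
    return 0;
}

After the scan, dst_slots[i] holds the dense position of each marked page, which is the role s_cu_page_mask plays when the kernel above copies kv_page_offsets into the compacted output; accumulating the per-chunk aggregate into scan_offset is what lets a block handle more pages than it has threads.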