@@ -176,6 +176,183 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
+// Grid: (num_blocks, num_heads).
+template<typename scalar_t>
+__global__ void gather_cached_kv_kernel(
+  scalar_t* __restrict__ out,             // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size]
+  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_heads, head_size, block_size]
+  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ cu_seqlens_k,   // cumulative sequence lengths: 'cu_seqlens_k' in '_flash_attn_forward', i.e. the prefix sums of cacheflow's 'context_lens'
+  const int num_seqs,
+  const int max_num_blocks_per_seq,
+  const int head_size,
+  const int block_size) {
+  // Each CUDA grid is mapped to (num_blocks, num_heads).
+  const int block_idx = blockIdx.x;
+  const int num_blocks = gridDim.x;
+  const int head_idx = blockIdx.y;
+  const int num_heads = gridDim.y;
+  // Each CUDA block is responsible for (head_size, block_size) elements.
+  const int thread_idx = threadIdx.x;
+  const int num_threads = blockDim.x;
+  // In the original attention kernel, each thread group fetches x elements at a time.
+  constexpr int x = 16 / sizeof(scalar_t);
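+  // e.g. x == 8 for 2-byte (half) elements and x == 4 for 4-byte (float) elements.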
+
+  // The index of the sequence this thread is working on.
+  int seq_idx;
+  // The index of the block in the sequence this thread is working on.
+  int local_block_idx;
+  // Calculate the sequence index and block index in the sequence.
+  int num_total_blocks = 0;
+  #pragma unroll
+  for (int i = 0; i < num_seqs; ++i) {
+    int context_len = cu_seqlens_k[i + 1] - cu_seqlens_k[i];
+    int num_blocks = (context_len + block_size - 1) / block_size;
+    num_total_blocks += num_blocks;
+    if (num_total_blocks > block_idx) {
+      seq_idx = i;
+      local_block_idx = block_idx - (num_total_blocks - num_blocks);
+      break;
+    }
+  }
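+  // Example: with block_size = 8 and context lengths {10, 5}, sequence 0 owns grid
+  // blocks 0-1 and sequence 1 owns grid block 2, so block_idx == 2 maps to
+  // (seq_idx = 1, local_block_idx = 0).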
+  // const int context_len = cu_seqlens_k[seq_idx];
+  // const int num_blocks = (context_len + block_size - 1) / block_size;
+  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+  const int physical_block_number = block_table[local_block_idx];
+
+  // Number of chunks handled by a CUDA block.
+  const int n_chunks = (head_size * block_size + (num_threads - 1)) / num_threads;
+  const int physical_cache_offset = (physical_block_number * num_heads + head_idx) * head_size * block_size;
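+  // Both caches store head_size * block_size contiguous elements per (block, head),
+  // so this offset addresses the (physical_block_number, head_idx) slice in either cache.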
+
+  // The common output pointer base used by both key and value:
+  scalar_t* common_out = out + (block_idx * block_size) * 3 * num_heads * head_size
+                             + head_idx * head_size;
+  // Key is the second tensor in QKV, so qkv_offset = 1.
+  scalar_t* key_out = common_out + 1 * num_heads * head_size;
+  // Value is the third tensor in QKV, so qkv_offset = 2.
+  scalar_t* value_out = common_out + 2 * num_heads * head_size;
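+  // Note: block_idx * block_size matches the packed token row base
+  // (cu_seqlens_k[seq_idx] + local_block_idx * block_size) only when every context
+  // length is a multiple of block_size.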
+
+  // Process the key in chunks.
+  #pragma unroll
+  for (int chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
+    const int offset = chunk_idx * num_threads + thread_idx;
+    if (offset >= head_size * block_size) {
+      break;
+    }
+    // Calculate offsets in [head_size/x, block_size, x].
+    const int head_offset = offset / x / block_size;
+    const int block_offset = offset / x % block_size;
+    const int x_offset = offset % x;
+
+    const scalar_t* k_ptr = k_cache + physical_cache_offset + offset;
+    scalar_t* out_ptr = key_out + block_offset * 3 * num_heads * head_size
+                                + head_offset * x + x_offset;
+    *out_ptr = __ldg(k_ptr);
+  }
+
+  // Process the value in chunks.
+  #pragma unroll
+  for (int chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
+    const int offset = chunk_idx * num_threads + thread_idx;
+    if (offset >= head_size * block_size) {
+      break;
+    }
+    // Calculate offsets in [head_size, block_size].
+    const int head_offset = offset / block_size;
+    const int block_offset = offset % block_size;
+
+    const scalar_t* v_ptr = v_cache + physical_cache_offset + offset;
+    scalar_t* out_ptr = value_out + block_offset * 3 * num_heads * head_size + head_offset;
+    *out_ptr = __ldg(v_ptr);
+  }
+}
+
+
+// Grid: (num_blocks, block_size).
+template<typename scalar_t>
+__global__ void gather_cached_kv_kernel_2(
+  scalar_t* __restrict__ out,             // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size]
+  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_heads, head_size, block_size]
+  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ cu_seqlens_k,   // cumulative sequence lengths: 'cu_seqlens_k' in '_flash_attn_forward', i.e. the prefix sums of cacheflow's 'context_lens'
+  const int num_seqs,
+  const int max_num_blocks_per_seq,
+  const int num_heads,
+  const int head_size) {
+  // Each CUDA grid is mapped to (num_blocks, block_size).
+  const int block_idx = blockIdx.x;
+  const int num_blocks = gridDim.x;
+  const int block_offset = blockIdx.y;
+  const int block_size = gridDim.y;
+  // Each CUDA block is responsible for one token, i.e. (num_heads, head_size) elements.
+  const int thread_idx = threadIdx.x;
+  const int num_threads = blockDim.x;
+  // In the original attention kernel, each thread group fetches x elements at a time.
+  constexpr int x = 16 / sizeof(scalar_t);
+
+  // The index of the sequence this thread is working on.
+  int seq_idx;
+  // The index of the block in the sequence this thread is working on.
+  int local_block_idx;
+  // Calculate the sequence index and block index in the sequence.
+  int num_total_blocks = 0;
+  #pragma unroll
+  for (int i = 0; i < num_seqs; ++i) {
+    int context_len = cu_seqlens_k[i + 1] - cu_seqlens_k[i];
+    int num_blocks = (context_len + block_size - 1) / block_size;
+    num_total_blocks += num_blocks;
+    if (num_total_blocks > block_idx) {
+      seq_idx = i;
+      local_block_idx = block_idx - (num_total_blocks - num_blocks);
+      break;
+    }
+  }
+
+  // const int context_len = cu_seqlens_k[seq_idx];
+  // const int num_blocks = (context_len + block_size - 1) / block_size;
+  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+  const int physical_block_number = block_table[local_block_idx];
+  const int physical_cache_offset = physical_block_number * num_heads * head_size * block_size;
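+  // Offset of the whole physical block (all heads): each block holds
+  // num_heads * head_size * block_size elements in both cache layouts.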
+
+  // The common output pointer base used by both key and value:
+  scalar_t* common_out = out + (block_idx * block_size + block_offset) * 3 * num_heads * head_size;
+  // Key is the second tensor in QKV, so qkv_offset = 1.
+  scalar_t* key_out = common_out + 1 * num_heads * head_size;
+  // Value is the third tensor in QKV, so qkv_offset = 2.
+  scalar_t* value_out = common_out + 2 * num_heads * head_size;
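+  // As in gather_cached_kv_kernel, block_idx * block_size + block_offset is the packed
+  // token row only when every context length is a multiple of block_size.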
+
+  // Process the key: each thread copies one or more elements of this token.
+  #pragma unroll
+  for (int i = thread_idx; i < num_heads * head_size; i += num_threads) {
+    // Calculate offsets in [num_heads, head_size/x, x].
+    const int head_idx = i / x / (head_size / x);
+    const int head_offset = i / x % (head_size / x);
+    const int x_offset = i % x;
+
+    const scalar_t* k_ptr = k_cache + physical_cache_offset
+                                    + head_idx * (head_size / x) * block_size * x
+                                    + head_offset * block_size * x
+                                    + block_offset * x
+                                    + x_offset;
+    key_out[head_idx * head_size + head_offset * x + x_offset] = __ldg(k_ptr);
+  }
+
+  // Process the value: each thread copies one or more elements of this token.
+  #pragma unroll
+  for (int i = thread_idx; i < num_heads * head_size; i += num_threads) {
+    // Calculate offsets in [num_heads, head_size].
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+
+    const scalar_t* v_ptr = v_cache + physical_cache_offset
+                                    + i * block_size  // equal to (head_idx * head_size + head_offset) * block_size
+                                    + block_offset;
+    value_out[i] = __ldg(v_ptr);
+  }
+}
+
 } // namespace cacheflow
 
 void reshape_and_cache(
@@ -215,3 +392,96 @@ void reshape_and_cache(
         x);
     });
 }
+
+/*
+// same group of threads will be working on the same block
+void gather_cached_kv(
+  torch::Tensor& qkv_out,         // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size]
+  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& cu_seqlens_k,    // cumulative sequence lengths ('cu_seqlens_k' in '_flash_attn_forward')
+  torch::Tensor& seqlens_k,       // CPU copy of 'cu_seqlens_k'
+  torch::Tensor& block_tables) {  // [num_seqs, max_num_blocks_per_seq]
+  const int num_seqs = cu_seqlens_k.size(0) - 1;
+  const int num_heads = value_cache.size(1);
+  const int head_size = value_cache.size(2);
+  const int block_size = value_cache.size(3);
+  // const int x = key_cache.size(4);
+  const int max_num_blocks_per_seq = block_tables.size(1);
+  const int* context_lens_ptr = cu_seqlens_k.data_ptr<int>();
+  const int* cpu_context_lens_ptr = seqlens_k.data_ptr<int>();
+
+  // calculate the total number of blocks
+  int num_total_blocks = 0;
+  for (int i = 0; i < num_seqs; ++i) {
+    int context_len = cpu_context_lens_ptr[i + 1] - cpu_context_lens_ptr[i];
+    int num_blocks = (context_len + block_size - 1) / block_size;
+    num_total_blocks += num_blocks;
+  }
+
+  constexpr int NUM_THREADS = 256;
+  dim3 grid(num_total_blocks, num_heads);
+  dim3 block(NUM_THREADS);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    key_cache.scalar_type(),
+    "gather_cached_kv_kernel",
+    [&] {
+      cacheflow::gather_cached_kv_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        qkv_out.data_ptr<scalar_t>(),
+        key_cache.data_ptr<scalar_t>(),
+        value_cache.data_ptr<scalar_t>(),
+        block_tables.data_ptr<int>(),
+        cu_seqlens_k.data_ptr<int>(),
+        num_seqs,
+        max_num_blocks_per_seq,
+        head_size,
+        block_size);
+    });
+}
+*/
+
+// The same group of threads works on the same cache block.
+void gather_cached_kv(
+  torch::Tensor& qkv_out,         // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size]
+  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& cu_seqlens_k,    // cumulative sequence lengths: 'cu_seqlens_k' in '_flash_attn_forward', i.e. the prefix sums of cacheflow's 'context_lens'
+  torch::Tensor& seqlens_k,       // CPU copy of 'cu_seqlens_k'
+  torch::Tensor& block_tables) {  // [num_seqs, max_num_blocks_per_seq]
+  const int num_seqs = cu_seqlens_k.size(0) - 1;
+  const int num_heads = value_cache.size(1);
+  const int head_size = value_cache.size(2);
+  const int block_size = value_cache.size(3);
+  // const int x = key_cache.size(4);
+  const int max_num_blocks_per_seq = block_tables.size(1);
+  const int* context_lens_ptr = cu_seqlens_k.data_ptr<int>();
+  const int* cpu_context_lens_ptr = seqlens_k.data_ptr<int>();
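+  // The CPU copy of the cumulative lengths lets the host size the grid below
+  // without synchronizing with the GPU.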
+
+  // Calculate the total number of blocks.
+  int num_total_blocks = 0;
+  for (int i = 0; i < num_seqs; ++i) {
+    int context_len = cpu_context_lens_ptr[i + 1] - cpu_context_lens_ptr[i];
+    int num_blocks = (context_len + block_size - 1) / block_size;
+    num_total_blocks += num_blocks;
+  }
+
+  dim3 grid(num_total_blocks, block_size);
+  dim3 block(std::min(num_heads * head_size, 512));
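+  // One thread block per (cache block, token position within the block); each thread
+  // copies one or more of that token's num_heads * head_size key and value elements.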
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    key_cache.scalar_type(),
+    "gather_cached_kv_kernel_2",
+    [&] {
+      cacheflow::gather_cached_kv_kernel_2<scalar_t><<<grid, block, 0, stream>>>(
+        qkv_out.data_ptr<scalar_t>(),
+        key_cache.data_ptr<scalar_t>(),
+        value_cache.data_ptr<scalar_t>(),
+        block_tables.data_ptr<int>(),
+        cu_seqlens_k.data_ptr<int>(),
+        num_seqs,
+        max_num_blocks_per_seq,
+        num_heads,
+        head_size);
+    });
+}