Commit e21845e

optimize
optimize with shared memory; better number of threads; update test; temporarily disable test; update
1 parent c267b1a commit e21845e

4 files changed: +433 -60 lines changed

benchmark/benchmark_attention.py

Lines changed: 41 additions & 9 deletions
@@ -6,7 +6,7 @@
 from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
 
-from cacheflow import attention_ops
+from cacheflow import attention_ops, cache_ops
 
 
 def benchmark(name, f, num_warmup = 10, num_iters = 100):
@@ -43,7 +43,7 @@ def benchmark_multi_query_cached_kv_attention(
     num_total_tokens = cu_query_lens[-1]
     qkv = torch.randn(
         num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
-    query, _, _ = qkv.unbind(dim=1)
+    query, key, value = qkv.unbind(dim=1)  # NOTE: this will not make a copy.
 
     # Create key and value cache.
     x = 16 // torch.tensor([], dtype=dtype).element_size()
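A note on the `# NOTE: this will not make a copy.` comment: torch.unbind returns views that share storage with qkv, so when gather_cached_kv (in the next hunk) writes gathered keys and values into qkv, the key and value tensors passed to _flash_attn_forward observe those writes without any extra copy. A minimal standalone sketch of that behavior (shapes here are illustrative):

import torch

qkv = torch.zeros(4, 3, 2, 8)            # [num_tokens, 3(QKV), num_heads, head_size]
query, key, value = qkv.unbind(dim=1)    # views into qkv; no copy is made
qkv[:, 1].fill_(1.0)                     # simulate a kernel writing into the K slice
assert bool(key.eq(1.0).all())           # the 'key' view observes the in-place write
assert not bool(query.eq(1.0).any())     # the Q slice is untouched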
@@ -72,21 +72,53 @@ def benchmark_multi_query_cached_kv_attention(
     scale = float(1.0 / (head_size ** 0.5))
     output = torch.empty(
         num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+
+    num_kv_tokens = sum(context_lens)
+    cu_context_lens = [0]
+    for context_len in context_lens:
+        cu_context_lens.append(cu_context_lens[-1] + context_len)
+    cpu_context_lens = torch.tensor(cu_context_lens, dtype=torch.int, device='cpu')
+    cu_context_lens = cpu_context_lens.cuda()
+    ref_output = torch.empty_like(output)
 
     # Run our implementation.
     def run_ours():
-        attention_ops.multi_query_cached_kv_attention(
-            cu_query_lens,
-            output,
-            query,
+        cache_ops.gather_cached_kv(
+            qkv,
             key_cache,
             value_cache,
-            scale,
+            cu_context_lens,
+            cpu_context_lens,
             block_tables,
-            context_len_tensor,
-            block_size,
+        )
+
+        _flash_attn_forward(
+            query,
+            key,
+            value,
+            ref_output,
+            cu_query_lens,
+            cu_context_lens,
+            max(query_lens),
             max_context_len,
+            dropout_p=0.0,
+            softmax_scale=scale,
+            causal=True,
+            return_softmax=False,
         )
+
+        # attention_ops.multi_query_cached_kv_attention(
+        #     cu_query_lens,
+        #     output,
+        #     query,
+        #     key_cache,
+        #     value_cache,
+        #     scale,
+        #     block_tables,
+        #     context_len_tensor,
+        #     block_size,
+        #     max_context_len,
+        # )
     benchmark('Ours', run_ours)
 
     # Upper bound: Flash attention.
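The cu_context_lens construction above follows the varlen convention used by _flash_attn_forward: an int32 tensor of length num_seqs + 1 holding cumulative sequence lengths, kept on the CPU (to size the kernel launch) and on the GPU (for the kernels to read). The Python loop is just a prefix sum; an equivalent sketch with illustrative lengths:

import torch

context_lens = [7, 128, 33]                        # per-sequence KV lengths
cu_context_lens = torch.cumsum(
    torch.tensor([0] + context_lens, dtype=torch.int32),
    dim=0, dtype=torch.int32)                      # tensor([0, 7, 135, 168])
# cu_context_lens[i + 1] - cu_context_lens[i] recovers context_lens[i], which is
# how both _flash_attn_forward and the gather kernels derive each sequence length.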

csrc/cache.cpp

Lines changed: 12 additions & 0 deletions
@@ -20,6 +20,14 @@ void reshape_and_cache(
   torch::Tensor& value_cache,
   torch::Tensor& slot_mapping);
 
+void gather_cached_kv(
+  torch::Tensor& qkv_out,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& cu_seqlens_k,
+  torch::Tensor& seqlens_k,
+  torch::Tensor& block_tables);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
     "swap_blocks",
@@ -33,4 +41,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "reshape_and_cache",
     &reshape_and_cache,
     "Reshape the key and value tensors and cache them");
+  m.def(
+    "gather_cached_kv",
+    &gather_cached_kv,
+    "Gather key and value from the cache into contiguous QKV tensors");
 }
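After rebuilding the extension, the new binding is reachable from Python as cache_ops.gather_cached_kv, taking the six tensors in the order declared above. A small usage sketch; the op and argument order come from this diff, while the shapes and sizes below are purely illustrative:

import torch
from cacheflow import cache_ops

num_seqs, num_heads, head_size, block_size, num_blocks = 2, 4, 64, 8, 16
dtype = torch.half
x = 16 // torch.tensor([], dtype=dtype).element_size()        # 16-byte fetches, as in the benchmark

context_lens = [5, 11]                                         # 16 KV tokens in total
cpu_context_lens = torch.tensor([0, 5, 16], dtype=torch.int)   # cumulative, length num_seqs + 1
cu_context_lens = cpu_context_lens.cuda()

key_cache = torch.randn(num_blocks, num_heads, head_size // x, block_size, x,
                        dtype=dtype, device='cuda')
value_cache = torch.randn(num_blocks, num_heads, head_size, block_size,
                          dtype=dtype, device='cuda')
block_tables = torch.randint(0, num_blocks, (num_seqs, 2), dtype=torch.int, device='cuda')
qkv = torch.empty(16, 3, num_heads, head_size, dtype=dtype, device='cuda')

# Gathers K and V for every cached token into the K/V slices of qkv.
cache_ops.gather_cached_kv(
    qkv, key_cache, value_cache, cu_context_lens, cpu_context_lens, block_tables)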

csrc/cache_kernels.cu

Lines changed: 270 additions & 0 deletions
@@ -176,6 +176,183 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
+// Grid: (num_blocks, num_heads).
+template<typename scalar_t>
+__global__ void gather_cached_kv_kernel(
+  scalar_t* __restrict__ out,           // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size]
+  const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
+  const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ cu_seqlens_k, // aka 'cu_seqlens_k' in '_flash_attn_forward', or 'context_lens' in cacheflow
+  const int num_seqs,
+  const int max_num_blocks_per_seq,
+  const int head_size,
+  const int block_size) {
+  // Each CUDA grid is mapped to (num_blocks, num_heads).
+  const int block_idx = blockIdx.x;
+  const int num_blocks = gridDim.x;
+  const int head_idx = blockIdx.y;
+  const int num_heads = gridDim.y;
+  // Each CUDA block is responsible for (head_size, block_size).
+  const int thread_idx = threadIdx.x;
+  const int num_threads = blockDim.x;
+  // In the original attention kernel, each thread group fetches x elements at a time.
+  constexpr int x = 16 / sizeof(scalar_t);
+
+  // The index of the sequence this thread is working on.
+  int seq_idx;
+  // The index of the block in the sequence this thread is working on.
+  int local_block_idx;
+  // Calculate the sequence index and the block index within the sequence.
+  int num_total_blocks = 0;
+  #pragma unroll
+  for (int i = 0; i < num_seqs; ++i) {
+    int context_len = cu_seqlens_k[i + 1] - cu_seqlens_k[i];
+    int num_blocks = (context_len + block_size - 1) / block_size;
+    num_total_blocks += num_blocks;
+    if (num_total_blocks > block_idx) {
+      seq_idx = i;
+      local_block_idx = block_idx - (num_total_blocks - num_blocks);
+      break;
+    }
+  }
+  // const int context_len = cu_seqlens_k[seq_idx];
+  // const int num_blocks = (context_len + block_size - 1) / block_size;
+  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+  const int physical_block_number = block_table[local_block_idx];
+
+  // Number of chunks handled by a CUDA block.
+  const int n_chunks = (head_size * block_size + (num_threads - 1)) / num_threads;
+  const int physical_cache_offset = (physical_block_number * num_heads + head_idx) * head_size * block_size;
+
+  // The common output pointer base used by both key and value:
+  scalar_t* common_out = out + (block_idx * block_size) * 3 * num_heads * head_size
+                             + head_idx * head_size;
+  // Key is the second tensor in QKV, so qkv_offset = 1.
+  scalar_t* key_out = common_out + 1 * num_heads * head_size;
+  // Value is the third tensor in QKV, so qkv_offset = 2.
+  scalar_t* value_out = common_out + 2 * num_heads * head_size;
+
+  // Process the key in chunks.
+  #pragma unroll
+  for (int chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
+    const int offset = chunk_idx * num_threads + thread_idx;
+    if (offset >= head_size * block_size) {
+      break;
+    }
+    // Calculate offsets in [head_size/x, block_size, x].
+    const int head_offset = offset / x / block_size;
+    const int block_offset = offset / x % block_size;
+    const int x_offset = offset % x;
+
+    const scalar_t* k_ptr = k_cache + physical_cache_offset + offset;
+    scalar_t* out_ptr = key_out + block_offset * 3 * num_heads * head_size
+                                + head_offset * x + x_offset;
+    *out_ptr = __ldg(k_ptr);
+  }
+
+  // Process the value in chunks.
+  #pragma unroll
+  for (int chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
+    const int offset = chunk_idx * num_threads + thread_idx;
+    if (offset >= head_size * block_size) {
+      break;
+    }
+    // Calculate offsets in [head_size, block_size].
+    const int head_offset = offset / block_size;
+    const int block_offset = offset % block_size;
+
+    const scalar_t* v_ptr = v_cache + physical_cache_offset + offset;
+    scalar_t* out_ptr = value_out + block_offset * 3 * num_heads * head_size + head_offset;
+    *out_ptr = __ldg(v_ptr);
+  }
+}
+
+
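Both kernels recover which sequence a grid block belongs to by walking the cumulative context lengths until the running block count exceeds blockIdx.x. A pure-Python sketch of that mapping, mirroring the loop in the kernel above (the helper name and example numbers are illustrative):

def map_block_idx(block_idx, cu_seqlens_k, block_size):
    """Return (seq_idx, local_block_idx) for a flattened grid block index."""
    num_seqs = len(cu_seqlens_k) - 1
    num_total_blocks = 0
    for i in range(num_seqs):
        context_len = cu_seqlens_k[i + 1] - cu_seqlens_k[i]
        num_blocks = (context_len + block_size - 1) // block_size  # ceil division
        num_total_blocks += num_blocks
        if num_total_blocks > block_idx:
            # block_idx falls inside sequence i; subtract the blocks of all
            # previous sequences to get the block index within the sequence.
            return i, block_idx - (num_total_blocks - num_blocks)
    raise ValueError('block_idx out of range')

# Example: two sequences of 5 and 11 tokens, block_size 8 -> 1 + 2 = 3 blocks total.
assert map_block_idx(0, [0, 5, 16], 8) == (0, 0)
assert map_block_idx(2, [0, 5, 16], 8) == (1, 1)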
+// Grid: (num_blocks, block_size).
+template<typename scalar_t>
+__global__ void gather_cached_kv_kernel_2(
+  scalar_t* __restrict__ out,           // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size]
+  const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
+  const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ cu_seqlens_k, // aka 'cu_seqlens_k' in '_flash_attn_forward', or 'context_lens' in cacheflow
+  const int num_seqs,
+  const int max_num_blocks_per_seq,
+  const int num_heads,
+  const int head_size) {
+  // Each CUDA grid is mapped to (num_blocks, block_size).
+  const int block_idx = blockIdx.x;
+  const int num_blocks = gridDim.x;
+  const int block_offset = blockIdx.y;
+  const int block_size = gridDim.y;
+  // Each CUDA block is responsible for (num_heads, head_size) elements at one block offset.
+  const int thread_idx = threadIdx.x;
+  const int num_threads = blockDim.x;
+  // In the original attention kernel, each thread group fetches x elements at a time.
+  constexpr int x = 16 / sizeof(scalar_t);
+
+  // The index of the sequence this thread is working on.
+  int seq_idx;
+  // The index of the block in the sequence this thread is working on.
+  int local_block_idx;
+  // Calculate the sequence index and the block index within the sequence.
+  int num_total_blocks = 0;
+  #pragma unroll
+  for (int i = 0; i < num_seqs; ++i) {
+    int context_len = cu_seqlens_k[i + 1] - cu_seqlens_k[i];
+    int num_blocks = (context_len + block_size - 1) / block_size;
+    num_total_blocks += num_blocks;
+    if (num_total_blocks > block_idx) {
+      seq_idx = i;
+      local_block_idx = block_idx - (num_total_blocks - num_blocks);
+      break;
+    }
+  }
+
+  // const int context_len = cu_seqlens_k[seq_idx];
+  // const int num_blocks = (context_len + block_size - 1) / block_size;
+  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+  const int physical_block_number = block_table[local_block_idx];
+  const int physical_cache_offset = physical_block_number * num_heads * head_size * block_size;
+
+  // The common output pointer base used by both key and value:
+  scalar_t* common_out = out + (block_idx * block_size + block_offset) * 3 * num_heads * head_size;
+  // Key is the second tensor in QKV, so qkv_offset = 1.
+  scalar_t* key_out = common_out + 1 * num_heads * head_size;
+  // Value is the third tensor in QKV, so qkv_offset = 2.
+  scalar_t* value_out = common_out + 2 * num_heads * head_size;
+
+  // Process the key in chunks.
+  #pragma unroll
+  for (int i = threadIdx.x; i < num_heads * head_size; i += blockDim.x) {
+    // Calculate offsets in [num_heads, head_size/x, x].
+    const int head_idx = i / x / (head_size / x);
+    const int head_offset = i / x % (head_size / x);
+    const int x_offset = i % x;
+
+    const scalar_t* k_ptr = k_cache + physical_cache_offset
+                                    + head_idx * (head_size / x) * block_size * x
+                                    + head_offset * block_size * x
+                                    + block_offset * x
+                                    + x_offset;
+    key_out[head_idx * head_size + head_offset * x + x_offset] = __ldg(k_ptr);
+  }
+
+  // Process the value in chunks.
+  #pragma unroll
+  for (int i = threadIdx.x; i < num_heads * head_size; i += blockDim.x) {
+    // Calculate offsets in [num_heads, head_size].
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+
+    const scalar_t* v_ptr = v_cache + physical_cache_offset
+                                    + i * block_size  // equal to (head_idx * head_size + head_offset) * block_size
+                                    + block_offset;
+    value_out[i] = __ldg(v_ptr);
+  }
+}
+
 } // namespace cacheflow
 
 void reshape_and_cache(
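For reference, the result both kernels are intended to produce can be written as a (slow) pure-PyTorch gather over the paged cache layouts noted in the comments above. A sketch; the helper name is illustrative, and the layouts are taken from the kernel comments:

import torch

def gather_cached_kv_reference(key_cache, value_cache, block_tables, context_lens):
    # key_cache:   [num_blocks, num_heads, head_size // x, block_size, x]
    # value_cache: [num_blocks, num_heads, head_size, block_size]
    num_heads, head_size, block_size = value_cache.shape[1:]
    keys, values = [], []
    for seq_idx, context_len in enumerate(context_lens):
        for token_idx in range(context_len):
            block_number = int(block_tables[seq_idx][token_idx // block_size])
            block_offset = token_idx % block_size
            # [num_heads, head_size // x, x] -> [num_heads, head_size]
            k = key_cache[block_number, :, :, block_offset, :].reshape(num_heads, head_size)
            v = value_cache[block_number, :, :, block_offset]   # [num_heads, head_size]
            keys.append(k)
            values.append(v)
    # Both stacked to [num_kv_tokens, num_heads, head_size], matching the K and V
    # slices of the qkv_out tensor the kernels write into.
    return torch.stack(keys), torch.stack(values)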
@@ -215,3 +392,96 @@ void reshape_and_cache(
         x);
     });
 }
+
+/*
+// same group of threads will be working on the same block
+void gather_cached_kv(
+  torch::Tensor& qkv_out,       // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size]
+  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& cu_seqlens_k,  // aka 'cu_seqlens_k' in '_flash_attn_forward', or 'context_lens' in cacheflow
+  torch::Tensor& seqlens_k,     // CPU version of 'cu_seqlens_k'
+  torch::Tensor& block_tables) { // [num_seqs, max_num_blocks_per_seq]
+  const int num_seqs = cu_seqlens_k.size(0) - 1;
+  const int num_heads = value_cache.size(1);
+  const int head_size = value_cache.size(2);
+  const int block_size = value_cache.size(3);
+  // const int x = key_cache.size(4);
+  const int max_num_blocks_per_seq = block_tables.size(1);
+  const int* context_lens_ptr = cu_seqlens_k.data_ptr<int>();
+  const int* cpu_context_lens_ptr = seqlens_k.data_ptr<int>();
+
+  // Calculate the total number of blocks.
+  int num_total_blocks = 0;
+  for (int i = 0; i < num_seqs; ++i) {
+    int context_len = cpu_context_lens_ptr[i + 1] - cpu_context_lens_ptr[i];
+    int num_blocks = (context_len + block_size - 1) / block_size;
+    num_total_blocks += num_blocks;
+  }
+
+  constexpr int NUM_THREADS = 256;
+  dim3 grid(num_total_blocks, num_heads);
+  dim3 block(NUM_THREADS);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    key_cache.scalar_type(),
+    "gather_cached_kv_kernel",
+    [&] {
+      cacheflow::gather_cached_kv_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        qkv_out.data_ptr<scalar_t>(),
+        key_cache.data_ptr<scalar_t>(),
+        value_cache.data_ptr<scalar_t>(),
+        block_tables.data_ptr<int>(),
+        cu_seqlens_k.data_ptr<int>(),
+        num_seqs,
+        max_num_blocks_per_seq,
+        head_size,
+        block_size);
+    });
+}
+*/
+
+// same group of threads will be working on the same block
+void gather_cached_kv(
+  torch::Tensor& qkv_out,       // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size]
+  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& cu_seqlens_k,  // aka 'cu_seqlens_k' in '_flash_attn_forward', or 'context_lens' in cacheflow
+  torch::Tensor& seqlens_k,     // CPU version of 'cu_seqlens_k'
+  torch::Tensor& block_tables) { // [num_seqs, max_num_blocks_per_seq]
+  const int num_seqs = cu_seqlens_k.size(0) - 1;
+  const int num_heads = value_cache.size(1);
+  const int head_size = value_cache.size(2);
+  const int block_size = value_cache.size(3);
+  // const int x = key_cache.size(4);
+  const int max_num_blocks_per_seq = block_tables.size(1);
+  const int* context_lens_ptr = cu_seqlens_k.data_ptr<int>();
+  const int* cpu_context_lens_ptr = seqlens_k.data_ptr<int>();
+
+  // Calculate the total number of blocks.
+  int num_total_blocks = 0;
+  for (int i = 0; i < num_seqs; ++i) {
+    int context_len = cpu_context_lens_ptr[i + 1] - cpu_context_lens_ptr[i];
+    int num_blocks = (context_len + block_size - 1) / block_size;
+    num_total_blocks += num_blocks;
+  }
+
+  dim3 grid(num_total_blocks, block_size);
+  dim3 block(std::min(num_heads * head_size, 512));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    key_cache.scalar_type(),
+    "gather_cached_kv_kernel_2",
+    [&] {
+      cacheflow::gather_cached_kv_kernel_2<scalar_t><<<grid, block, 0, stream>>>(
+        qkv_out.data_ptr<scalar_t>(),
+        key_cache.data_ptr<scalar_t>(),
+        value_cache.data_ptr<scalar_t>(),
+        block_tables.data_ptr<int>(),
+        cu_seqlens_k.data_ptr<int>(),
+        num_seqs,
+        max_num_blocks_per_seq,
+        num_heads,
+        head_size);
+    });
+}
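The active host function launches gather_cached_kv_kernel_2 with one grid entry per (cache block, slot-within-block) pair and up to 512 threads striding over the num_heads * head_size elements. A small Python sketch of that launch-geometry arithmetic, with illustrative numbers:

def launch_geometry(context_lens, block_size, num_heads, head_size):
    # One x-dimension grid entry per logical cache block across all sequences.
    num_total_blocks = sum((l + block_size - 1) // block_size for l in context_lens)
    grid = (num_total_blocks, block_size)      # mirrors dim3 grid(num_total_blocks, block_size)
    block = min(num_heads * head_size, 512)    # mirrors dim3 block(std::min(num_heads * head_size, 512))
    return grid, block

# E.g. 2 sequences of 5 and 11 tokens, block_size 8, 4 heads of size 64:
print(launch_geometry([5, 11], 8, 4, 64))      # ((3, 8), 256)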
