-
-
Notifications
You must be signed in to change notification settings - Fork 11.1k
[OPTIMIZATION] Optimizes the single_query_cached_kv_attention kernel #420
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -83,6 +83,7 @@ __global__ void single_query_cached_kv_attention_kernel( | |
| const float* __restrict__ alibi_slopes, // [num_heads] | ||
| const int q_stride) { | ||
| constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); | ||
| constexpr int NUM_THREAD_GROUPS_LOWER_BOUND = NUM_THREADS / THREAD_GROUP_SIZE; | ||
| constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; | ||
| constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; | ||
| const int thread_idx = threadIdx.x; | ||
|
|
@@ -116,12 +117,15 @@ __global__ void single_query_cached_kv_attention_kernel( | |
| // th vectors of the query, and so on. | ||
| // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. | ||
| const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; | ||
| Q_vec q_vecs[NUM_VECS_PER_THREAD]; | ||
| __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; | ||
| if (thread_group_idx <= NUM_THREAD_GROUPS_LOWER_BOUND) { | ||
|
||
| #pragma unroll | ||
| for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { | ||
| const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; | ||
| q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); | ||
| for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS_LOWER_BOUND) { | ||
| const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; | ||
| q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); | ||
|
||
| } | ||
| } | ||
| __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs | ||
|
|
||
| // Memory planning. | ||
| extern __shared__ char shared_mem[]; | ||
|
|
@@ -169,7 +173,7 @@ __global__ void single_query_cached_kv_attention_kernel( | |
|
|
||
| // Compute dot product. | ||
| // This includes a reduction across the threads in the same thread group. | ||
| float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs); | ||
| float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs); | ||
| // Add the ALiBi bias if slopes are given. | ||
| qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0; | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.