Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions csrc/attention/attention_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ __global__ void single_query_cached_kv_attention_kernel(
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride) {
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_THREAD_GROUPS_LOWER_BOUND = NUM_THREADS / THREAD_GROUP_SIZE;
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
Expand Down Expand Up @@ -116,12 +117,15 @@ __global__ void single_query_cached_kv_attention_kernel(
// th vectors of the query, and so on.
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
Q_vec q_vecs[NUM_VECS_PER_THREAD];
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
if (thread_group_idx <= NUM_THREAD_GROUPS_LOWER_BOUND) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This if seems redundant if we assume NUM_THREADS should is divisible by THREAD_GROUP_SIZE?

Copy link
Contributor Author

@naed90 naed90 Aug 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replaced with an assert.

#pragma unroll
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS_LOWER_BOUND) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have questions about loading query from gmem to q_vecs in smem. Threads within a thread group are neighbored in thread dim, writing q_vecs[thread_group_offset][i] will cause grouped threads accessing the same bank in smem (assume NUM_VECS_PER_THREAD * VEC_SIZE * sizeof(scalar_t) % 32 == 0). Can this be fixed by saving q_vecs in col-major, or is this a misunderstanding?

}
}
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs

// Memory planning.
extern __shared__ char shared_mem[];
Expand Down Expand Up @@ -169,7 +173,7 @@ __global__ void single_query_cached_kv_attention_kernel(

// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;

Expand Down