[OPTIMIZATION] Optimizes the single_query_cached_kv_attention kernel #420
Conversation
single_query_cached_kv_attention kernel
See #421 for a detailed description and analysis of this commit.

Hi @naed90, overall LGTM; I just had one small nitpick, and it looks like there are some formatting issues to address.

ty.

@WoosukKwon @zhuohan123 hey, what do you think?

Hey @naed90, thanks for submitting the PR and apologies for the late response. I was busy for the last few days. Will take a look at your issue and PR today.

@WoosukKwon bump :)

Tested a bit on the latency side: before/after optimization (the measurement screenshots are not recoverable here).
Thank you for your contribution! Left some small comments. We should be able to merge this after the changes.
csrc/attention/attention_kernels.cu
Outdated
  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
- Q_vec q_vecs[NUM_VECS_PER_THREAD];
+ __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+ if (thread_group_idx <= NUM_THREAD_GROUPS_LOWER_BOUND) {
This if seems redundant if we assume NUM_THREADS is divisible by THREAD_GROUP_SIZE?
Replaced with an assert.
Co-authored-by: Zhuohan Li <[email protected]>
LGTM! Thank you again for your hard work and detailed profiling!
Bank conflict in loading q_vecs.
csrc/attention/attention_kernels.cu
Outdated
- q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+ for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS_LOWER_BOUND) {
+   const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
+   q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
I have a question about loading the query from gmem into q_vecs in smem. Threads within a thread group are adjacent in the thread dimension, so writing q_vecs[thread_group_offset][i] will cause the grouped threads to access the same bank in smem (assuming NUM_VECS_PER_THREAD * VEC_SIZE * sizeof(scalar_t) % 32 == 0). Could this be fixed by storing q_vecs in column-major order, or is this a misunderstanding?
Instead of having each thread group fetch the query head (which causes 64x the memory to be read), we have all threads in the block share the task of loading the query head. On a benchmark running 1000 sequences through LLaMA-13B on an A100 (80 GB), this improves throughput by 1.10x.