-
Notifications
You must be signed in to change notification settings - Fork 480
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Extend paged attention to support query_len>1 #8328
Conversation
page_indices, # [batch_size, pages_per_sequence] | ||
num_kv_pages_per_compute_block, | ||
num_queries_per_compute_block, | ||
use_kernel=True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hey @WoosukKwon, this is the integration point between vLLM and torch_xla. I'm thinking if vLLM can switch this flag use_kernel
perhaps by using some flags. I want to use the nonkernel version as a per baseline. Do you know if it possible?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For dynamo, it's similar. The integration point is at def multi_queries_paged_attention_xla(
in the same file.
torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
Outdated
Show resolved
Hide resolved
torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
Show resolved
Hide resolved
q_index = q_blk_idx * num_queries_per_compute_block | ||
kv_index = kv_blk_idx * kv_seq_len_per_kv_compute_blk | ||
kv_len = lengths_ref[b] | ||
row_ids = (kv_len - query_len) + q_index + jax.lax.broadcasted_iota( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here, we assume the input query corresponds to the last (q_len) of the input kv. For example, if the input q_len is 8, and kv_len is 24, we assume the query corresponds to the kv at index [16. 24), and applies the causal mask accordingly.
@WoosukKwon please let us know if this assumption is valid or nor for the use cases in vLLM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes that's the desired behavior. Thanks for checking it out with me!
This PR extends the existing paged attention kernel to support query_len>1. Additionally, it upgrades the flash attention from v1 to v2.
Test plan:
cc: @miladm