[WIP] Support for cached multi-query attention towards speculative decoding #1679

Closed
skrider wants to merge 14 commits into vllm-project:main from skrider:cached-mqa
Conversation

@skrider
Contributor

@skrider skrider commented Nov 16, 2023

Initial prototype of cached multi-query attention that takes advantage of implementation details of the single query cached attention kernel to adapt it to the multi-query setting.

Given n sequences, each with a draft of at most k tokens to be verified, the kernel greedily caches all draft keys and values, then calls paged_attention on n * k query vectors. The KV caches of drafts are "symbolically linked" to the original sequence's cache, and "future" tokens are masked out by interpolating the sequence_len passed to the paged attention kernel from context_len to context_len + draft_len.
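The sequence-length interpolation described above can be sketched as follows. This is a minimal illustration with hypothetical helper names (not code from the PR), assuming draft query j should attend to the full context plus draft tokens 0..j:

```python
def interpolated_seq_lens(context_lens, draft_lens):
    """Per-query "sequence lengths" handed to the paged attention kernel.

    For a sequence with context length ctx and draft length k, draft query j
    is given a sequence length of ctx + j + 1, so the kernel's existing
    length-based masking hides all "future" draft tokens from that query.
    """
    seq_lens = []
    for ctx, k in zip(context_lens, draft_lens):
        # one entry per draft-token query of this sequence
        seq_lens.extend(ctx + j + 1 for j in range(k))
    return seq_lens

# Two sequences: context lengths 10 and 7, draft lengths 3 and 2
# yield 5 query vectors with lengths [11, 12, 13, 8, 9].
lens = interpolated_seq_lens([10, 7], [3, 2])
```

Because each query reuses the single-query kernel's own length check, no new masking logic is needed inside the kernel itself.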

While this kernel supports dynamic draft lengths, this is handled somewhat inefficiently by masking rather than by dynamic shapes, leaving potential room for improvement.
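To make the inefficiency concrete: with a masking-based layout, every sequence occupies k_max query slots regardless of its actual draft length, so slots past the real draft are computed and then discarded. A small illustrative sketch (hypothetical names, not PR code):

```python
def draft_slot_mask(draft_lens, k_max):
    """Boolean mask over a [num_seqs, k_max] query layout.

    True marks a real draft-token query; False marks a padded slot whose
    attention output is computed and then thrown away. The wasted work is
    proportional to sum(k_max - k_i), which a dynamic-shape kernel avoids.
    """
    mask = [[j < k for j in range(k_max)] for k in draft_lens]
    wasted_slots = sum(k_max - k for k in draft_lens)
    return mask, wasted_slots

# drafts of length 3 and 1, padded to k_max = 3: 2 of 6 slots are wasted
mask, wasted = draft_slot_mask([3, 1], 3)
```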

Performance has yet to be profiled. The intention behind this PR is to serve as a reference implementation against which a more performant MQA kernel can be developed.

@beginlner
Contributor

beginlner commented Nov 23, 2023

I've made a pull request to flash-attention that enables support for a blocked KV cache in flash-decoding, which supports MQA. The performance is nearly identical to the original. You might want to check it out.
Dao-AILab/flash-attention#678

@skrider
Contributor Author

skrider commented Nov 27, 2023

@beginlner thanks for the info. Reading https://github.com/microsoft/DeepSpeed-Kernels/blob/main/dskernels/inf_flash_attn/blocked_flash/flash_fwd_kernel.h as well.

@Lvjinhong commented

So far, is there any progress on enabling speculative decoding for vLLM? Additionally, I'm wondering if the implementation of this kernel might result in increased GPU memory usage.

@Lvjinhong

When can this branch be merged? In the version I am currently using, there is:

op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if is_hip() else None,

Is the Flash operation supported only for HIP?

@skrider skrider closed this Feb 28, 2024
WeNeedMoreCode pushed a commit to WeNeedMoreCode/vllm that referenced this pull request Dec 15, 2025
Make model_runner_v1 more readable

- vLLM version: v0.9.2
- vLLM main:
vllm-project@baed180

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>

3 participants