Skip to content

Conversation

@sarckk
Copy link
Collaborator

@sarckk sarckk commented Jun 17, 2025

Motivation

KV cache techniques like SwiftKV reduce computation required during prefill. This is harder to implement in V1 where the scheduler groups tokens for prefill and decode in the same batch. This PR adds instrumentation to support prefill compute savings in V1 in KV cache sharing setups where KV sharing is used such that certain tokens can be skipped during prefill (as KV target layers have already populated the necessary key/value tensors required for decoding).

Example

Let's say we have a 24 layer model where first 12 layers allocate their own KV caches and next 12 layers re-use the shared KV cache of its corresponding KV target layer. Then given input prompt sequence of N tokens, we can skip prefill for N-1 tokens for the last 12 layers, because the key/value tensors used for decoding is already populated in the KV caches of the first 12 layers. Because vLLM v1 scheduler does not distinguish prefill/decode and employs continuous batching, we can instead perform forward on the last 12 layers with a reduced input size.

For example, if we have request 0 and request 1 with 4 prompt tokens each, then we might have tokens batched as such:

<----r0---> <----r1---->
[0, 1, 2, 3, 4, 5, 6, 7]

For the first 12 self-attention layers, we can do forward with the full input [0, 1, 2, 3, 4, 5, 6, 7], while for the last 12 cross-attention layers, we can do forward with the last token for each request [3,7], as these are the only positions where valid logits are required to sample output tokens from.

Frontend changes

This PR adds a new --kv-sharing-skip-prefill arg which is added to the CacheConfig. This causes FlashAttention backend to compute an extra set of metadata assuming prefill skip, but changes are still required on model side to take advantage of this.

Attention metadata

Attention metadata needs to be changed to account for the different query offsets and max lengths in the shared KV layers for which N-1 tokens are skipped during prefill.

Correctness Test

Unit test show outputs are roughly equivalent with and without this optimization (exact numerics will differ as batched mm op will yield slightly different results depending on batch size)

pytest tests/v1/e2e/test_kv_sharing_truncated_prefill.py::test_kv_sharing_truncated_prefill

Perf comparison

Set up: single batch and input length of 8192. Using compile+piecewise cuda graph

TestQwen2ForCausalLM model forward trace with optimization (enable_kv_sharing_truncated_prefill=True)

second layer group takes 9.7ms

Screenshot 2025-07-02 at 21 00 10

Trace without optimization (enable_kv_sharing_truncated_prefill=False)

second layer group takes 16.6ms

Screenshot 2025-07-02 at 20 58 20

@gemini-code-assist
Copy link
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Jun 17, 2025
@houseroad houseroad requested a review from heheda12345 June 17, 2025 01:18
@sarckk sarckk requested a review from LucasWilkinson June 17, 2025 02:11
@heheda12345
Copy link
Collaborator

My concern is whether this optimization is too model specific. It works for models that the first k layers have kv cache. Does it work for models that every m layers share the same kv cache like Hunyuan?

@sarckk
Copy link
Collaborator Author

sarckk commented Jun 17, 2025

My concern is whether this optimization is too model specific. It works for models that the first k layers have kv cache. Does it work for models that every m layers share the same kv cache like Hunyuan?

It only works for the case where the first k layers have kv cache as you said. For general KV sharing cases, it should also apply for last N layers that reuse the KV cache (ie there are no other layers afterwards that have its own KV cache). So I agree it will not apply to a majority of models, but then I'm not sure if there is a better way to implement this kind of functionality.

Copy link
Collaborator

@heheda12345 heheda12345 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took a quick pass on this PR.

And I'm curious about your plan to support piecewise cuda graph. We need cuda graph for num_total_tokens in the first few layers, and num_decode_tokens in the following layers.

vllm/envs.py Outdated
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to add it as a cli arg.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This branch is not true for hunyuan-style kv sharing.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added logic to detect which layers are 'eligible' for this prefill skip optimization

@mergify
Copy link

mergify bot commented Jun 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sarckk.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added needs-rebase qwen Related to Qwen models labels Jun 18, 2025
Copy link
Collaborator

@LucasWilkinson LucasWilkinson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would really like to try to keep build signature of the metadata builders as simple as possible so hopefully we can create some nice unit testing infrastructure in the future. Do we really need to add decode_only_common_attn_metadata to the build call signature? can we make the kv sharing layers a different KVSpec and have separate build calls at this level:

for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
builder = self.attn_metadata_builders[kv_cache_group_id]
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id],
kv_cache_group_spec.kv_cache_spec,
builder,
)
attn_metadata_i = (builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
))

we should probably be doing this for local attention too but that was added before we had the hybrid-KV cache (which enabled different build calls for different layer groups). We should probably migrate local attention to a scheme like this too

Copy link
Collaborator

@LucasWilkinson LucasWilkinson Jun 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason we need to pass decode_only_common_attn_metadata as a separate arg; is there a reason we can't just use a different build call at the gpu model runner level? i.e. here-ish:

for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
builder = self.attn_metadata_builders[kv_cache_group_id]
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id],
kv_cache_group_spec.kv_cache_spec,
builder,
)
attn_metadata_i = (builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
))

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea I initially had a separate build() call at the model runner level, but I needed to set this as a property of attention metadata for all different backends, and they don't share a common schema. So I thought I could pass the info and let each backend decide what to do with it.

But I do agree that your approach is a better abstraction, will follow up on that

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we move this logic into metadata builder?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved this logic to flash attn metadata builder

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry I think I missed this so not sure what the code looked like at this point but I think ideally we would keep this common metadata manipulation outside of the metadata builders so we can naturally just support all the backends (assuming we can keep a clean build interface). This is important for blackwell where FlashInfer has the best perf. I actually want to do something similar for local-attention since that could also be done via pure CommonAttentionMetadata manipulation and would enable iRoPe for FlashInfer.

see: #19719 (comment)

@sarckk sarckk force-pushed the decode-only-attn branch 2 times, most recently from 541f2a5 to a9783c3 Compare July 2, 2025 23:28
@sarckk sarckk changed the title [V1] Perf optimization for layers reusing shared KV cache [V1] Perf optimization for early exit inference Jul 3, 2025
@sarckk sarckk changed the title [V1] Perf optimization for early exit inference [V1] Partial prefill skip for layers reusing shared KV cache Jul 3, 2025
@mergify mergify bot added the frontend label Jul 3, 2025
@sarckk sarckk force-pushed the decode-only-attn branch from 587c1d6 to 226edcf Compare July 3, 2025 17:13
@sarckk sarckk marked this pull request as ready for review July 3, 2025 17:13
@heheda12345
Copy link
Collaborator

May be unrelated to this PR. We also need an elegant way to skip preparing kv for layers that don't need them.

qkv, _ = self.qkv_proj(hidden_states)

q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)

@heheda12345
Copy link
Collaborator

@sarckk Here is a PR for v0 YOCO optimization. #20702 Though it is simplified due to ignoring chunked prefill and cuda graph, you can take a look and check whether there are anything you can learn.
The key logic in that PR:
https://github.com/vllm-project/vllm/blob/875e85bf6c00072d9d969dab4310d79aebf30471/vllm/attention/backends/differential_flash_attn.py#L757-L999
https://github.com/vllm-project/vllm/blob/875e85bf6c00072d9d969dab4310d79aebf30471/vllm/model_executor/models/phi4flash.py#L561-L571

sarckk added 10 commits July 11, 2025 13:50
Signed-off-by: Yong Hoon Shin <[email protected]>
Signed-off-by: Yong Hoon Shin <[email protected]>
Signed-off-by: Yong Hoon Shin <[email protected]>
Signed-off-by: Yong Hoon Shin <[email protected]>
Signed-off-by: Yong Hoon Shin <[email protected]>
Signed-off-by: Yong Hoon Shin <[email protected]>
@sarckk sarckk force-pushed the decode-only-attn branch from e171dd5 to 3cd2474 Compare July 11, 2025 21:18
@mergify mergify bot removed the needs-rebase label Jul 11, 2025
@mergify mergify bot added the tpu Related to Google TPUs label Jul 14, 2025
@mergify
Copy link

mergify bot commented Jul 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sarckk.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@sarckk
Copy link
Collaborator Author

sarckk commented Jul 23, 2025

this PR is still being worked on, we are going to first decouple kv cache group and attention metadata builder to allow different layers to have different metadata builders

EDIT: see #21590 for updated PR

Signed-off-by: Yong Hoon Shin <[email protected]>
@github-actions
Copy link

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale Over 90 days of inactivity label Oct 27, 2025
@heheda12345
Copy link
Collaborator

redo in #22628

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

frontend needs-rebase qwen Related to Qwen models stale Over 90 days of inactivity tpu Related to Google TPUs v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants