Skip to content

[Core][DSV4] Skip caching SWA blocks that can never serve a prefix-cache hit#42258

Merged
jeejeelee merged 1 commit into
vllm-project:mainfrom
ivanium:fix/dsv4-mask-blocks
May 15, 2026
Merged

[Core][DSV4] Skip caching SWA blocks that can never serve a prefix-cache hit#42258
jeejeelee merged 1 commit into
vllm-project:mainfrom
ivanium:fix/dsv4-mask-blocks

Conversation

@ivanium
Copy link
Copy Markdown
Collaborator

@ivanium ivanium commented May 11, 2026

Purpose

DeepSeek-V4 pairs full-attention layers with SWA layers with different block sizes and window sizes. The full attn layers have block size of 256, while SWA layers and compressors have block sizes of 64, 4, or 8.

  • Within each 256-aligned segment, only the trailing tail = ceil((sliding_window - 1) / block_size) blocks are reachable by SWA's right-to-left scan.
  • Hash-blocks past the last 256-aligned boundary are unreachable because hits are always lcm-aligned.

This PR drops the unreachable blocks at cache-time via a new alignment_tokens kwarg threaded through cache_blocks -> SlidingWindowManager._cache_block_mask -> BlockPool.cache_full_blocks. Non-hybrid coordinators pass alignment_tokens=None and hit a fast path identical to the existing behavior.

Test Plan

New tests in tests/v1/core/test_prefix_caching.py:

  • test_hybrid_cache_blocks_swa_tail_window_only — full-attn block_size=32, SWA block_size=8, sliding_window=8 (lcm=32, tail=1, per_segment=4). After caching 8 SWA hash-blocks, asserts only hashes 3 and 7 are in the prefix-cache hash map; 0–2 and 4–6 are not.
  • test_hybrid_cache_blocks_clamped_to_lcm — full-attn block_size=32, SWA block_size=16, sliding_window=32. After caching 7 SWA hash-blocks (112 tokens), asserts hashes 0–5 are cached and hash 6 (past the last lcm boundary) is not.

Test Result

Passed. Claude is used to generate test cases but I have reviewed them.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Copy link
Copy Markdown

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added the v1 label May 11, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request optimizes prefix caching for hybrid models by introducing a block masking mechanism, ensuring that only blocks capable of serving future cache hits (such as the tail window in Sliding Window Attention) are stored. Key changes include adding a block_mask to the BlockPool, implementing LCM-aligned caching in the KVCacheCoordinator, and updating cache managers to support sparse hit semantics. Review feedback highlighted potential assertion failures and memory leaks when handling shared physical blocks in models like DeepSeek-V4. Additionally, a bug was identified where token_ids in BlockStored events are not correctly filtered when blocks are masked, which could impact distributed caching systems.

Comment thread vllm/v1/core/block_pool.py
Comment on lines +306 to +307
if block_mask is not None and not block_mask[i - num_cached_blocks]:
continue
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

While extra_keys_list is correctly filtered by the block_mask, the token_ids slice passed to the BlockStored event (line 319) is not filtered. This results in a mismatch between the number of hashes and the number of tokens in the event when blocks are masked out (e.g., in SWA tail-window caching). This will likely break distributed caching consumers or offloading mechanisms that rely on these events. Please ensure token_ids is filtered to only include tokens for the blocks actually being cached.

Comment thread tests/v1/core/test_prefix_caching.py
@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label May 13, 2026
Copy link
Copy Markdown
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ivanium!

Looks like we should rebase now

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
@ivanium ivanium force-pushed the fix/dsv4-mask-blocks branch from 289c58c to c106593 Compare May 14, 2026 22:09
@jeejeelee jeejeelee merged commit 4b364f8 into vllm-project:main May 15, 2026
56 checks passed
jasl pushed a commit to jasl/vllm that referenced this pull request May 15, 2026
PR vllm-project#42258 introduced SlidingWindowManager._cache_block_mask() to skip
caching SWA blocks that can never serve a prefix-cache hit.  When
Eagle/MTP speculative decoding is active the mask is too aggressive —
it skips blocks that eagle's modified lookup actually needs, resulting
in 0% prefix cache hit rate.

Eagle changes the SWA hit logic in two ways:
  1. sliding_window_contiguous_blocks += 1 (needs one extra block)
  2. post_pop_blocks = i (instead of i+1), shifting alignment

Fix: detect SWA managers inside eagle attention groups at coordinator
init time and disable the cache block mask for them.

Signed-off-by: Alex Bilichenko <abilichenko@gmail.com>
(cherry picked from commit b90c495)
Signed-off-by: jasl <jasl9187@hotmail.com>
jasl pushed a commit to jasl/vllm that referenced this pull request May 16, 2026
PR vllm-project#42258 introduced SlidingWindowManager._cache_block_mask() to skip
caching SWA blocks that can never serve a prefix-cache hit.  When
Eagle/MTP speculative decoding is active the mask is too aggressive —
it skips blocks that eagle's modified lookup actually needs, resulting
in 0% prefix cache hit rate.

Eagle changes the SWA hit logic in two ways:
  1. sliding_window_contiguous_blocks += 1 (needs one extra block)
  2. post_pop_blocks = i (instead of i+1), shifting alignment

Fix: detect SWA managers inside eagle attention groups at coordinator
init time and disable the cache block mask for them.

Signed-off-by: Alex Bilichenko <abilichenko@gmail.com>
(cherry picked from commit b90c495)
Signed-off-by: jasl <jasl9187@hotmail.com>
jasl pushed a commit to jasl/vllm that referenced this pull request May 16, 2026
PR vllm-project#42258 introduced SlidingWindowManager._cache_block_mask() to skip
caching SWA blocks that can never serve a prefix-cache hit.  When
Eagle/MTP speculative decoding is active the mask is too aggressive —
it skips blocks that eagle's modified lookup actually needs, resulting
in 0% prefix cache hit rate.

Eagle changes the SWA hit logic in two ways:
  1. sliding_window_contiguous_blocks += 1 (needs one extra block)
  2. post_pop_blocks = i (instead of i+1), shifting alignment

Fix: detect SWA managers inside eagle attention groups at coordinator
init time and disable the cache block mask for them.

Signed-off-by: Alex Bilichenko <abilichenko@gmail.com>
(cherry picked from commit b90c495)
Signed-off-by: jasl <jasl9187@hotmail.com>
omerpaz95 pushed a commit to omerpaz95/vllm that referenced this pull request May 18, 2026
…che hit (vllm-project#42258)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
omerpaz95 pushed a commit to omerpaz95/vllm that referenced this pull request May 18, 2026
…che hit (vllm-project#42258)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
jasl pushed a commit to jasl/vllm that referenced this pull request May 18, 2026
PR vllm-project#42258 introduced SlidingWindowManager._cache_block_mask() to skip
caching SWA blocks that can never serve a prefix-cache hit.  When
Eagle/MTP speculative decoding is active the mask is too aggressive —
it skips blocks that eagle's modified lookup actually needs, resulting
in 0% prefix cache hit rate.

Eagle changes the SWA hit logic in two ways:
  1. sliding_window_contiguous_blocks += 1 (needs one extra block)
  2. post_pop_blocks = i (instead of i+1), shifting alignment

Fix: detect SWA managers inside eagle attention groups at coordinator
init time and disable the cache block mask for them.

Signed-off-by: Alex Bilichenko <abilichenko@gmail.com>
(cherry picked from commit b90c495)
Signed-off-by: jasl <jasl9187@hotmail.com>
jasl pushed a commit to jasl/vllm that referenced this pull request May 19, 2026
PR vllm-project#42258 introduced SlidingWindowManager._cache_block_mask() to skip
caching SWA blocks that can never serve a prefix-cache hit.  When
Eagle/MTP speculative decoding is active the mask is too aggressive —
it skips blocks that eagle's modified lookup actually needs, resulting
in 0% prefix cache hit rate.

Eagle changes the SWA hit logic in two ways:
  1. sliding_window_contiguous_blocks += 1 (needs one extra block)
  2. post_pop_blocks = i (instead of i+1), shifting alignment

Fix: detect SWA managers inside eagle attention groups at coordinator
init time and disable the cache block mask for them.

Signed-off-by: Alex Bilichenko <abilichenko@gmail.com>
(cherry picked from commit b90c495)
Signed-off-by: jasl <jasl9187@hotmail.com>
jasl pushed a commit to jasl/vllm that referenced this pull request May 19, 2026
PR vllm-project#42258 introduced SlidingWindowManager._cache_block_mask() to skip
caching SWA blocks that can never serve a prefix-cache hit.  When
Eagle/MTP speculative decoding is active the mask is too aggressive —
it skips blocks that eagle's modified lookup actually needs, resulting
in 0% prefix cache hit rate.

Eagle changes the SWA hit logic in two ways:
  1. sliding_window_contiguous_blocks += 1 (needs one extra block)
  2. post_pop_blocks = i (instead of i+1), shifting alignment

Fix: detect SWA managers inside eagle attention groups at coordinator
init time and disable the cache block mask for them.

Signed-off-by: Alex Bilichenko <abilichenko@gmail.com>
(cherry picked from commit b90c495)
Signed-off-by: jasl <jasl9187@hotmail.com>
mfylcek pushed a commit to mfylcek/vllm that referenced this pull request May 19, 2026
…che hit (vllm-project#42258)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
jhu960213 pushed a commit to jhu960213/vllm that referenced this pull request May 20, 2026
…che hit (vllm-project#42258)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
jasl pushed a commit to jasl/vllm that referenced this pull request May 20, 2026
PR vllm-project#42258 introduced SlidingWindowManager._cache_block_mask() to skip
caching SWA blocks that can never serve a prefix-cache hit.  When
Eagle/MTP speculative decoding is active the mask is too aggressive —
it skips blocks that eagle's modified lookup actually needs, resulting
in 0% prefix cache hit rate.

Eagle changes the SWA hit logic in two ways:
  1. sliding_window_contiguous_blocks += 1 (needs one extra block)
  2. post_pop_blocks = i (instead of i+1), shifting alignment

Fix: detect SWA managers inside eagle attention groups at coordinator
init time and disable the cache block mask for them.

Signed-off-by: Alex Bilichenko <abilichenko@gmail.com>
(cherry picked from commit b90c495)
Signed-off-by: jasl <jasl9187@hotmail.com>
h1t35h pushed a commit to h1t35h/vllm that referenced this pull request May 21, 2026
…che hit (vllm-project#42258)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
jasl pushed a commit to jasl/vllm that referenced this pull request May 22, 2026
PR vllm-project#42258 introduced SlidingWindowManager._cache_block_mask() to skip
caching SWA blocks that can never serve a prefix-cache hit.  When
Eagle/MTP speculative decoding is active the mask is too aggressive —
it skips blocks that eagle's modified lookup actually needs, resulting
in 0% prefix cache hit rate.

Eagle changes the SWA hit logic in two ways:
  1. sliding_window_contiguous_blocks += 1 (needs one extra block)
  2. post_pop_blocks = i (instead of i+1), shifting alignment

Fix: detect SWA managers inside eagle attention groups at coordinator
init time and disable the cache block mask for them.

Signed-off-by: Alex Bilichenko <abilichenko@gmail.com>
(cherry picked from commit b90c495)
Signed-off-by: jasl <jasl9187@hotmail.com>
jasl pushed a commit to jasl/vllm that referenced this pull request May 22, 2026
PR vllm-project#42258 introduced SlidingWindowManager._cache_block_mask() to skip
caching SWA blocks that can never serve a prefix-cache hit.  When
Eagle/MTP speculative decoding is active the mask is too aggressive —
it skips blocks that eagle's modified lookup actually needs, resulting
in 0% prefix cache hit rate.

Eagle changes the SWA hit logic in two ways:
  1. sliding_window_contiguous_blocks += 1 (needs one extra block)
  2. post_pop_blocks = i (instead of i+1), shifting alignment

Fix: detect SWA managers inside eagle attention groups at coordinator
init time and disable the cache block mask for them.

Signed-off-by: Alex Bilichenko <abilichenko@gmail.com>
(cherry picked from commit b90c495)
Signed-off-by: jasl <jasl9187@hotmail.com>
DoradusResearch pushed a commit to DoradusResearch/vllm that referenced this pull request May 23, 2026
PR vllm-project#42258 introduced SlidingWindowManager._cache_block_mask() to skip
caching SWA blocks that can never serve a prefix-cache hit.  When
Eagle/MTP speculative decoding is active the mask is too aggressive —
it skips blocks that eagle's modified lookup actually needs, resulting
in 0% prefix cache hit rate.

Eagle changes the SWA hit logic in two ways:
  1. sliding_window_contiguous_blocks += 1 (needs one extra block)
  2. post_pop_blocks = i (instead of i+1), shifting alignment

Fix: detect SWA managers inside eagle attention groups at coordinator
init time and disable the cache block mask for them.

Signed-off-by: Alex Bilichenko <abilichenko@gmail.com>
(cherry picked from commit b90c495)
Signed-off-by: jasl <jasl9187@hotmail.com>
mlow pushed a commit to mlow/vllm that referenced this pull request May 27, 2026
PR vllm-project#42258 introduced SlidingWindowManager._cache_block_mask() to skip
caching SWA blocks that can never serve a prefix-cache hit.  When
Eagle/MTP speculative decoding is active the mask is too aggressive —
it skips blocks that eagle's modified lookup actually needs, resulting
in 0% prefix cache hit rate.

Eagle changes the SWA hit logic in two ways:
  1. sliding_window_contiguous_blocks += 1 (needs one extra block)
  2. post_pop_blocks = i (instead of i+1), shifting alignment

Fix: detect SWA managers inside eagle attention groups at coordinator
init time and disable the cache block mask for them.

Signed-off-by: Alex Bilichenko <abilichenko@gmail.com>
(cherry picked from commit b90c495)
Signed-off-by: jasl <jasl9187@hotmail.com>
mlow pushed a commit to mlow/vllm that referenced this pull request May 27, 2026
PR vllm-project#42258 introduced SlidingWindowManager._cache_block_mask() to skip
caching SWA blocks that can never serve a prefix-cache hit.  When
Eagle/MTP speculative decoding is active the mask is too aggressive —
it skips blocks that eagle's modified lookup actually needs, resulting
in 0% prefix cache hit rate.

Eagle changes the SWA hit logic in two ways:
  1. sliding_window_contiguous_blocks += 1 (needs one extra block)
  2. post_pop_blocks = i (instead of i+1), shifting alignment

Fix: detect SWA managers inside eagle attention groups at coordinator
init time and disable the cache block mask for them.

Signed-off-by: Alex Bilichenko <abilichenko@gmail.com>
(cherry picked from commit b90c495)
Signed-off-by: jasl <jasl9187@hotmail.com>
jasl pushed a commit to jasl/vllm that referenced this pull request May 27, 2026
PR vllm-project#42258 introduced SlidingWindowManager._cache_block_mask() to skip
caching SWA blocks that can never serve a prefix-cache hit.  When
Eagle/MTP speculative decoding is active the mask is too aggressive —
it skips blocks that eagle's modified lookup actually needs, resulting
in 0% prefix cache hit rate.

Eagle changes the SWA hit logic in two ways:
  1. sliding_window_contiguous_blocks += 1 (needs one extra block)
  2. post_pop_blocks = i (instead of i+1), shifting alignment

Fix: detect SWA managers inside eagle attention groups at coordinator
init time and disable the cache block mask for them.

Signed-off-by: Alex Bilichenko <abilichenko@gmail.com>
(cherry picked from commit b90c495)
Signed-off-by: jasl <jasl9187@hotmail.com>
jasl pushed a commit to jasl/vllm that referenced this pull request May 28, 2026
PR vllm-project#42258 introduced SlidingWindowManager._cache_block_mask() to skip
caching SWA blocks that can never serve a prefix-cache hit.  When
Eagle/MTP speculative decoding is active the mask is too aggressive —
it skips blocks that eagle's modified lookup actually needs, resulting
in 0% prefix cache hit rate.

Eagle changes the SWA hit logic in two ways:
  1. sliding_window_contiguous_blocks += 1 (needs one extra block)
  2. post_pop_blocks = i (instead of i+1), shifting alignment

Fix: detect SWA managers inside eagle attention groups at coordinator
init time and disable the cache block mask for them.

Signed-off-by: Alex Bilichenko <abilichenko@gmail.com>
(cherry picked from commit b90c495)
Signed-off-by: jasl <jasl9187@hotmail.com>
Liuweixiong0118 pushed a commit to Liuweixiong0118/vllm that referenced this pull request Jun 1, 2026
…che hit (vllm-project#42258)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Liuweixiong0118 <lwx34158427@gmail.com>
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
…che hit (vllm-project#42258)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
andakai pushed a commit to andakai/vllm that referenced this pull request Jun 4, 2026
…che hit (vllm-project#42258)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants