Skip to content

[Bugfix] Cache the EAGLE/MTP lookahead block in the SWA prefix-cache mask#44082

Merged
ywang96 merged 2 commits into
vllm-project:mainfrom
ivanium:fix/eagle-cache-mask
Jun 2, 2026
Merged

[Bugfix] Cache the EAGLE/MTP lookahead block in the SWA prefix-cache mask#44082
ywang96 merged 2 commits into
vllm-project:mainfrom
ivanium:fix/eagle-cache-mask

Conversation

@ivanium
Copy link
Copy Markdown
Collaborator

@ivanium ivanium commented May 30, 2026

Summary

PR #42258 added SlidingWindowManager._cache_block_mask() to skip caching SWA blocks that can never serve a prefix-cache hit. When EAGLE/MTP speculative decoding is active and the cache-hit alignment (the LCM of per-group block sizes) is larger than the SWA window, that mask is too aggressive: EAGLE's lookup needs tail + 1 contiguous cached blocks, and the extra +1 block lives at the first position past each aligned segment boundary — exactly a position the mask skipped. The result is that EAGLE + SWA finds no prefix-cache hit.

This PR fixes the mask so the one extra lookahead block EAGLE requires is cached, while keeping the rest of the #42258 optimization intact.

Note

MooncakeStoreConnector also needs a similar fix. Will be addressed in a follow-up PR.
SimpleCPUOffloadingConnector is not affected because it reuses KVCacheCoordinator logic directly.

What changed

  • SlidingWindowManager: factor the "contiguous blocks needed for a hit" calculation into _contiguous_blocks_for_hit(window, block_size, use_eagle), shared by both the cache-hit lookup and the cache mask so they stay in sync.
  • Rework reachable_block_mask (formerly _cache_block_mask) to mark a block reachable iff it falls in the need-wide run ending at an aligned boundary's right edge, applying the EAGLE shift=1 when EAGLE is active. This keeps the EAGLE lookahead block (one past the boundary) eligible to be cached.
  • HybridKVCacheCoordinator.cache_blocks: when a manager is an EAGLE group, extend num_tokens_to_cache by one block past the aligned boundary so the lookahead block is actually written into the prefix-cache hash map.
  • Introduce a small SpecGroup NamedTuple to carry the per-spec-group use_eagle bit explicitly, and propagate it to each SingleTypeKVCacheManager (grouped SWA siblings share one cache-hit lookup, so the EAGLE drop is decided per spec group). This replaces the ad-hoc eagle_attn_group_indices set.
  • Rename the use_eagle parameter of find_longest_cache_hit to drop_eagle to disambiguate "this group uses EAGLE" (a manager attribute) from "drop the last matched block on this lookup pass" (a per-pass decision in the hybrid fixed-point loop).

Why this is not duplicating PR #42784

PR #42784 fixes the same underlying bug but by disabling the SWA cache mask entirely whenever EAGLE is active — which caches every SWA block and so gives up the memory-saving benefit that #42258 was added to provide.

This PR instead preserves the #42258 optimization and caches only the single additional lookahead block that EAGLE actually needs per aligned segment. The masking logic and the cache-hit lookup are driven from one shared helper, so they cannot drift apart. It also covers the grouped-SWA-siblings case (multiple SWA groups sharing one spec, one of which is an EAGLE/MTP group) and the block_size != alignment_tokens (Gemma-style different-page-size) path.

I am happy to consolidate with #42784 if maintainers prefer one approach.

Test plan

.venv/bin/python -m pytest \
  tests/v1/core/test_prefix_caching.py \
  tests/v1/core/test_single_type_kv_cache_manager.py -q
# 72 passed

New regression tests added:

  • test_eagle_swa_alignment_caches_extra_block — EAGLE + SWA with sliding_window <= alignment finds a non-zero cache hit.
  • test_eagle_swa_boundary_caches_post_boundary_block — the first block past an alignment boundary (the EAGLE lookahead block) is cached.
  • test_eagle_grouped_swa_siblings_use_same_cache_mask — grouped SWA siblings cache the lookahead block together.

Lint: pre-commit run --files <changed files> — ruff check, ruff format, and mypy all pass.

Notes

This change was developed with AI assistance (Claude Code). The submitter has reviewed every changed line and run the tests above.

@ivanium
Copy link
Copy Markdown
Collaborator Author

ivanium commented May 30, 2026

@mergify mergify Bot added the kv-connector label May 30, 2026
@ivanium ivanium marked this pull request as ready for review May 30, 2026 23:02
@ivanium ivanium force-pushed the fix/eagle-cache-mask branch 2 times, most recently from ad183e7 to 8a1248b Compare June 1, 2026 06:24
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
drop_eagle: Whether to drop the last matched block for EAGLE/MTP.
Copy link
Copy Markdown
Contributor

@wzhao18 wzhao18 Jun 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find variable name drop_eagle less intuitive to understand than use_eagle + a comment suggesting eagle requires dropping the last matched block. Is there distinction between the two names?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I keep use_eagle as a member variable of SingleTypeKVCacheManager, since it is an instance property. In contrast, drop_eagle is specific to this class method and can be set to False even for a manager with EAGLE layers. For example, this happens during the convergence loop, where the drop has already been applied.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. Thanks for the explanation.

nit: Maybe it would be better to rename it to something like drop_eagle_block and add a comment saying this could be false even when for kv manager with eager layers.

Copy link
Copy Markdown
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ivanium

@ivanium ivanium added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 1, 2026
@ivanium ivanium force-pushed the fix/eagle-cache-mask branch from 8a1248b to 66ad620 Compare June 1, 2026 22:53
ivanium added 2 commits June 2, 2026 08:33
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
@ivanium ivanium force-pushed the fix/eagle-cache-mask branch from 66ad620 to 7056962 Compare June 2, 2026 08:46
@ywang96 ywang96 merged commit e9e08c4 into vllm-project:main Jun 2, 2026
71 checks passed
@ivanium ivanium deleted the fix/eagle-cache-mask branch June 2, 2026 19:21
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
…mask (vllm-project#44082)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
bnellnm pushed a commit to neuralmagic/vllm that referenced this pull request Jun 4, 2026
…mask (vllm-project#44082)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
andakai pushed a commit to andakai/vllm that referenced this pull request Jun 4, 2026
…mask (vllm-project#44082)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
JisoLya pushed a commit to JisoLya/vllm that referenced this pull request Jun 5, 2026
…mask (vllm-project#44082)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: JisoLya <523420504@qq.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working kv-connector ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants