Skip to content

[Core][Refactor]: thread scheduler_block_size into KVCacheManager and KVCacheCoordinator#44165

Merged
ywang96 merged 3 commits into
vllm-project:mainfrom
ivanium:refactor/scheduler-block-size
Jun 2, 2026
Merged

[Core][Refactor]: thread scheduler_block_size into KVCacheManager and KVCacheCoordinator#44165
ywang96 merged 3 commits into
vllm-project:mainfrom
ivanium:refactor/scheduler-block-size

Conversation

@ivanium

@ivanium ivanium commented Jun 1, 2026

Copy link
Copy Markdown
Collaborator

Purpose

This is a small, behavior-preserving refactor that threads an explicit scheduler_block_size through KVCacheManagerKVCacheCoordinatorSingleTypeKVCacheManager, instead of having HybridKVCacheCoordinator recompute the LCM of group block sizes internally.

Today the scheduler already resolves the scheduling-alignment granularity via resolve_kv_cache_block_sizes (returned as scheduler_block_size, the LCM of all group block sizes for the multi-group non-context-parallel case) and stores it as Scheduler.block_size. Separately, HybridKVCacheCoordinator independently recomputed the same quantity as self.lcm_block_size = lcm(*block_sizes). This PR removes that duplicate computation and instead passes the already-resolved value down, making the alignment invariant a single explicit input rather than a value derived in two places.

This is a preliminary step in prep for refactoring/merging #43447 (selective prefix-cache retention for sliding-window KV cache), which needs the scheduling block size available at the manager/coordinator level. Landing the plumbing on its own keeps that follow-up focused on the retention logic.

Behavioral equivalence

  • HybridKVCacheCoordinator.cache_blocks and find_longest_cache_hit now align on self.scheduler_block_size instead of self.lcm_block_size. For the only configuration that reaches HybridKVCacheCoordinator (multiple KV cache groups, context parallelism disabled), resolve_kv_cache_block_sizes returns exactly math.lcm(*group_block_sizes) — identical to the old internal computation over the same set of groups. Hybrid groups + context parallelism is rejected upstream in resolve_kv_cache_block_sizes, so there is no configuration where the two values could diverge.
  • self.scheduler_block_size is also stored on SingleTypeKVCacheManager. It is not consumed yet in this PR; it is the plumbing that [Prefix Caching] DeepSeekv4 - Support selective prefix-cache retention for sliding-window KV cache #43447 builds on.

All get_kv_cache_coordinator / KVCacheManager constructor sites are updated (scheduler and simple_kv_offload). The Mooncake store path uses its own coordinator and already carries its own scheduler_block_size; it is untouched here.

Why this is not duplicating an existing PR

A search of open PRs (scheduler_block_size, block-size threading into KVCacheManager) returns no overlap. #36317 ("Adjust alignment block size according attn supported kernel sizes") changes how the alignment block size is chosen per attention kernel — a different concern from threading the already-resolved value through the manager/coordinator. This PR adds no new behavior and changes no defaults.

Test Plan

.venv/bin/python -m pytest \
  tests/v1/core/test_prefix_caching.py \
  tests/v1/core/test_single_type_kv_cache_manager.py \
  tests/v1/core/test_kv_cache_utils.py -v

Tests are updated to pass scheduler_block_size. test_prefix_caching.py adds a small make_kv_cache_manager helper that derives scheduler_block_size from the config (LCM of group block sizes), mirroring resolve_kv_cache_block_sizes for the non-context-parallel path so call sites don't repeat it.

Test Result

126 passed in 35.08s

pre-commit run (ruff, mypy) passes on all changed files.


AI assistance (Claude Code) was used while preparing this change. The submitter has reviewed every changed line and run the tests above.

@mergify mergify Bot added the v1 label Jun 1, 2026
@ivanium ivanium force-pushed the refactor/scheduler-block-size branch from ccc75d7 to c221114 Compare June 1, 2026 05:15
@ivanium ivanium marked this pull request as ready for review June 1, 2026 05:15
@ywang96 ywang96 added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 1, 2026
@ivanium ivanium changed the title refactor: thread scheduler_block_size into KVCacheManager and KVCacheCoordinator [Core][Refactor]: thread scheduler_block_size into KVCacheManager and KVCacheCoordinator Jun 1, 2026

@njhill njhill left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ivanium

Comment on lines +49 to +50
# The scheduling granularity (LCM of all group block sizes), must be a multiple
# of the hash_block_size and the block size of each group.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could/should we add an assert here for this?

ivanium added 3 commits June 1, 2026 22:25
…heCoordinator

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
@ivanium ivanium force-pushed the refactor/scheduler-block-size branch from ece24fc to 8c8a6e5 Compare June 1, 2026 22:26
@mergify mergify Bot added the kv-connector label Jun 1, 2026
@ywang96 ywang96 merged commit 7c37096 into vllm-project:main Jun 2, 2026
64 checks passed
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
…nd KVCacheCoordinator (vllm-project#44165)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
bnellnm pushed a commit to neuralmagic/vllm that referenced this pull request Jun 4, 2026
…nd KVCacheCoordinator (vllm-project#44165)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
andakai pushed a commit to andakai/vllm that referenced this pull request Jun 4, 2026
…nd KVCacheCoordinator (vllm-project#44165)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
JisoLya pushed a commit to JisoLya/vllm that referenced this pull request Jun 5, 2026
…nd KVCacheCoordinator (vllm-project#44165)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: JisoLya <523420504@qq.com>
knight0528 pushed a commit to knight0528/vllm that referenced this pull request Jun 8, 2026
…nd KVCacheCoordinator (vllm-project#44165)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
waqahmed-amd-fi pushed a commit to waqahmed-amd-fi/vllm that referenced this pull request Jun 10, 2026
…nd KVCacheCoordinator (vllm-project#44165)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Waqar Ahmed <waqar.ahmed@amd.com>
@ivanium ivanium deleted the refactor/scheduler-block-size branch June 13, 2026 23:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

kv-connector ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants