From 134689373e8cd356f173bb7c986b9371fb3bc4a7 Mon Sep 17 00:00:00 2001 From: Dao Le Date: Sun, 26 Apr 2026 22:04:20 +0000 Subject: [PATCH 1/4] [Bugfix] Cap SWA/chunked-local runtime admission to startup bound `SlidingWindowSpec.max_memory_usage_bytes` and `ChunkedLocalAttentionSpec.max_memory_usage_bytes` size the pool at startup with a recycling-aware bound (one window/chunk window + `max_num_batched_tokens`, plus a 1-block alignment slack for SWA). At runtime, however, `SingleTypeKVCacheManager.get_num_blocks_to_allocate` returns `cdiv(num_tokens, block_size)` for a fresh request, since `get_num_skipped_tokens(0) == 0`. That over-counts: chunked prefill invokes `remove_skipped_blocks` between chunks, which swaps out-of-window blocks for the null block and returns their slots to the pool, so the per-request real-held block count plateaus at the recycling-aware bound. The mismatch deadlocks long prompts on hybrid full+SWA models when the pool is sized at the startup minimum -- the admission gate rejects what startup was sized to admit (issue #39734). Fix: - Hoist the recycling-aware bound onto the spec as `max_admission_blocks_per_request`, and have `max_memory_usage_bytes` call it so the startup pool sizer and the runtime admission gate share one source of truth (drift would re-introduce #39734 or, worse, mid-prefill OOM). - Plumb `max_num_batched_tokens` through `KVCacheManager` -> `KVCacheCoordinator` -> `get_manager_for_kv_cache_spec`. `KVCacheManager` defaults the parameter to `max_model_len` (a no-op cap) so non-scheduler call sites keep their prior behavior; the scheduler and the simple CPU offload scheduler pass the real value. - `SlidingWindowManager` and `ChunkedLocalAttentionManager` cap demand at the same per-request bound in `get_num_blocks_to_allocate`. The invariant remains `sum(reservations) <= pool` and per-request peak <= reservation (held by `remove_skipped_blocks`), so total real-held <= pool. Tests: - `test_can_fit_full_sequence_swa_cap_admits_long_prompt`: hybrid full+SWA with the pool at the startup minimum admits a prompt longer than the SWA window + chunk. - `test_can_fit_full_sequence_full_attention_still_gates_oversized`: the cap doesn't loosen the full-attention gate. Co-authored-by: Claude Signed-off-by: Dao Le --- tests/v1/core/test_prefix_caching.py | 108 +++++++++++++++++ .../core/test_single_type_kv_cache_manager.py | 3 + vllm/v1/core/kv_cache_coordinator.py | 13 ++ vllm/v1/core/kv_cache_manager.py | 7 ++ vllm/v1/core/sched/scheduler.py | 1 + vllm/v1/core/single_type_kv_cache_manager.py | 113 +++++++++++++++++- vllm/v1/kv_cache_interface.py | 73 +++++++---- vllm/v1/simple_kv_offload/manager.py | 3 + 8 files changed, 294 insertions(+), 27 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 22220599f158..8863a835ad42 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -2513,3 +2513,111 @@ def test_block_lookup_cache_multi_blocks_per_key(): assert cache.pop(key1, 11) is block11 assert cache.get_one_block(key1) is None assert cache.pop(key1, 12) is None + + +def test_can_fit_full_sequence_swa_cap_admits_long_prompt(): + """Hybrid full+SWA model with a pool sized at the startup minimum should + admit a prompt longer than the SWA cap, because SlidingWindowManager + recycles blocks during chunked prefill (issue #39734).""" + block_size = 16 + sliding_window = 4 * block_size # 64 tokens + max_num_batched_tokens = 8 * block_size # 128 tokens + max_model_len = 64 * block_size # 1024 tokens — much larger than the SWA cap + # Startup pool sizing: full demands cdiv(max_model_len, bs) = 64 blocks, + # SWA demands cdiv(SW-1+max_batched, bs) + 1 = cdiv(191, 16) + 1 = 13. + # Pool minimum = 64 + 13 = 77; +1 for the null block. + num_blocks = 64 + 13 + 1 + + config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + FullAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ), + ), + KVCacheGroupSpec( + ["layer_swa"], + SlidingWindowSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=sliding_window, + ), + ), + ], + ) + + manager = KVCacheManager( + config, + max_model_len=max_model_len, + max_num_batched_tokens=max_num_batched_tokens, + enable_caching=True, + hash_block_size=block_size, + ) + + # A prompt that is shorter than max_model_len but longer than SW + chunk: + # cdiv(prompt_len, bs) = 32 blocks. Without the cap, admission would + # demand 32 (full) + 32 (SWA) = 64 blocks. With the cap, SWA contributes + # only 13, so total = 32 + 13 = 45 ≤ pool size. + prompt_len = 32 * block_size + req = make_request("long", list(range(prompt_len)), block_size, sha256) + + assert manager.can_fit_full_sequence(req) + + +def test_can_fit_full_sequence_full_attention_still_gates_oversized(): + """The cap only loosens the SWA group; a prompt that exceeds the + full-attention pool capacity must still be rejected.""" + block_size = 16 + sliding_window = 4 * block_size + max_num_batched_tokens = 8 * block_size + max_model_len = 64 * block_size + # Provide a tiny pool — even a small prompt should be rejected. + num_blocks = 5 + + config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + FullAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ), + ), + KVCacheGroupSpec( + ["layer_swa"], + SlidingWindowSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=sliding_window, + ), + ), + ], + ) + + manager = KVCacheManager( + config, + max_model_len=max_model_len, + max_num_batched_tokens=max_num_batched_tokens, + enable_caching=True, + hash_block_size=block_size, + ) + + # 16 blocks of full attention demand alone exceeds the 5-block pool. + prompt_len = 16 * block_size + req = make_request("oversized", list(range(prompt_len)), block_size, sha256) + + assert not manager.can_fit_full_sequence(req) diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index b05040ebe2a6..08fda7593e28 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -22,11 +22,13 @@ def get_sliding_window_manager(sliding_window_spec, block_pool, enable_caching=True): + # Tests don't exercise admission gating; pass a large cap that is a no-op. return SlidingWindowManager( sliding_window_spec, block_pool=block_pool, enable_caching=enable_caching, kv_cache_group_id=0, + max_admission_blocks_per_request=10**9, ) @@ -38,6 +40,7 @@ def get_chunked_local_attention_manager( block_pool=block_pool, enable_caching=enable_caching, kv_cache_group_id=0, + max_admission_blocks_per_request=10**9, ) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index eaa95dfe49f7..40ab175e7598 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -34,6 +34,7 @@ def __init__( self, kv_cache_config: KVCacheConfig, max_model_len: int, + max_num_batched_tokens: int, use_eagle: bool, enable_caching: bool, enable_kv_cache_events: bool, @@ -59,6 +60,8 @@ def __init__( self.single_type_managers = tuple( get_manager_for_kv_cache_spec( kv_cache_spec=kv_cache_group.kv_cache_spec, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_model_len, block_pool=self.block_pool, enable_caching=enable_caching, kv_cache_group_id=i, @@ -265,6 +268,7 @@ def __init__( self, kv_cache_config: KVCacheConfig, max_model_len: int, + max_num_batched_tokens: int, use_eagle: bool, enable_kv_cache_events: bool, dcp_world_size: int, @@ -275,6 +279,7 @@ def __init__( super().__init__( kv_cache_config, max_model_len, + max_num_batched_tokens, use_eagle, False, enable_kv_cache_events, @@ -310,6 +315,7 @@ def __init__( self, kv_cache_config: KVCacheConfig, max_model_len: int, + max_num_batched_tokens: int, use_eagle: bool, enable_caching: bool, enable_kv_cache_events: bool, @@ -321,6 +327,7 @@ def __init__( super().__init__( kv_cache_config, max_model_len, + max_num_batched_tokens, use_eagle, enable_caching, enable_kv_cache_events, @@ -375,6 +382,7 @@ def __init__( self, kv_cache_config: KVCacheConfig, max_model_len: int, + max_num_batched_tokens: int, use_eagle: bool, enable_caching: bool, enable_kv_cache_events: bool, @@ -386,6 +394,7 @@ def __init__( super().__init__( kv_cache_config, max_model_len, + max_num_batched_tokens, use_eagle, enable_caching, enable_kv_cache_events, @@ -547,6 +556,7 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: def get_kv_cache_coordinator( kv_cache_config: KVCacheConfig, max_model_len: int, + max_num_batched_tokens: int, use_eagle: bool, enable_caching: bool, enable_kv_cache_events: bool, @@ -559,6 +569,7 @@ def get_kv_cache_coordinator( return KVCacheCoordinatorNoPrefixCache( kv_cache_config, max_model_len, + max_num_batched_tokens, use_eagle, enable_kv_cache_events, dcp_world_size=dcp_world_size, @@ -570,6 +581,7 @@ def get_kv_cache_coordinator( return UnitaryKVCacheCoordinator( kv_cache_config, max_model_len, + max_num_batched_tokens, use_eagle, enable_caching, enable_kv_cache_events, @@ -581,6 +593,7 @@ def get_kv_cache_coordinator( return HybridKVCacheCoordinator( kv_cache_config, max_model_len, + max_num_batched_tokens, use_eagle, enable_caching, enable_kv_cache_events, diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index dcec5e05bf97..83aa26bd96f0 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -109,6 +109,7 @@ def __init__( kv_cache_config: KVCacheConfig, max_model_len: int, hash_block_size: int, + max_num_batched_tokens: int | None = None, enable_caching: bool = True, use_eagle: bool = False, log_stats: bool = False, @@ -118,6 +119,11 @@ def __init__( metrics_collector: KVCacheMetricsCollector | None = None, ) -> None: self.max_model_len = max_model_len + # When unset, fall back to `max_model_len` so the recycling-aware cap + # collapses to the prior (uncapped) admission behavior. The scheduler + # always supplies the real value at runtime. + if max_num_batched_tokens is None: + max_num_batched_tokens = max_model_len self.enable_caching = enable_caching self.use_eagle = use_eagle @@ -131,6 +137,7 @@ def __init__( self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, + max_num_batched_tokens=max_num_batched_tokens, use_eagle=self.use_eagle, enable_caching=self.enable_caching, enable_kv_cache_events=enable_kv_cache_events, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 40b5899f0457..f267f98bd8fd 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -225,6 +225,7 @@ def __init__( self.kv_cache_manager = KVCacheManager( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, + max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens, enable_caching=self.cache_config.enable_prefix_caching, use_eagle=self.use_eagle, log_stats=self.log_stats, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 30061462008f..9c3cc14c65c7 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -479,9 +479,63 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: class SlidingWindowManager(SingleTypeKVCacheManager): - def __init__(self, kv_cache_spec: SlidingWindowSpec, **kwargs) -> None: + def __init__( + self, + kv_cache_spec: SlidingWindowSpec, + *, + max_admission_blocks_per_request: int, + **kwargs, + ) -> None: super().__init__(kv_cache_spec, **kwargs) self.sliding_window = kv_cache_spec.sliding_window + # Recycling-aware admission cap: matches the bound used to size the + # pool at startup in `SlidingWindowSpec.max_memory_usage_bytes`. Per + # request, `remove_skipped_blocks` keeps the real-held block count + # from exceeding this; the admission gate composes per-request caps + # so total reservations never exceed the pool. + self._max_admission_blocks_per_request = max_admission_blocks_per_request + + def get_num_blocks_to_allocate( + self, + request_id: str, + num_tokens: int, + new_computed_blocks: Sequence[KVCacheBlock], + total_computed_tokens: int, + num_tokens_main_model: int, + ) -> int: + """Return the admission *reservation* (not lifetime block touches). + + For a fresh request, the base implementation would ask for + ``cdiv(num_tokens, block_size)`` blocks because + ``get_num_skipped_tokens(0) == 0``. That over-counts: chunked prefill + invokes :meth:`remove_skipped_blocks` between chunks, which swaps + out-of-window blocks for ``null_block`` and returns their slots to + the pool, so per-request real-held blocks plateau at roughly one + window's worth. + + We therefore cap demand at the same recycling-aware bound that + :meth:`SlidingWindowSpec.max_memory_usage_bytes + ` + used to size the pool at startup. The two formulas MUST stay in sync + (drift re-introduces the deadlock from issue #39734 or, worse, + mid-prefill OOM); both derive from + :meth:`SlidingWindowSpec.max_admission_blocks_per_request + `. + + Safety: per-request peak ≤ reservation (guaranteed by + :meth:`remove_skipped_blocks`), and the admission gate ensures + ``sum(reservations) ≤ pool``, so ``sum( peak_real_held) ≤ pool``. + """ + capped_num_tokens = min( + num_tokens, self._max_admission_blocks_per_request * self.block_size + ) + return super().get_num_blocks_to_allocate( + request_id, + capped_num_tokens, + new_computed_blocks, + total_computed_tokens, + num_tokens_main_model, + ) @classmethod def find_longest_cache_hit( @@ -618,9 +672,50 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): - def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, **kwargs) -> None: + def __init__( + self, + kv_cache_spec: ChunkedLocalAttentionSpec, + *, + max_admission_blocks_per_request: int, + **kwargs, + ) -> None: super().__init__(kv_cache_spec, **kwargs) self.attention_chunk_size = kv_cache_spec.attention_chunk_size + # Recycling-aware admission cap, mirroring the startup pool sizer in + # `ChunkedLocalAttentionSpec.max_memory_usage_bytes`. See + # `SlidingWindowManager` for the safety argument. + self._max_admission_blocks_per_request = max_admission_blocks_per_request + + def get_num_blocks_to_allocate( + self, + request_id: str, + num_tokens: int, + new_computed_blocks: Sequence[KVCacheBlock], + total_computed_tokens: int, + num_tokens_main_model: int, + ) -> int: + """Return the admission *reservation* (not lifetime block touches). + + Caps demand at the recycling-aware bound mirroring + :meth:`ChunkedLocalAttentionSpec.max_memory_usage_bytes + `, + which the startup pool sizer uses. Both formulas derive from + :meth:`ChunkedLocalAttentionSpec.max_admission_blocks_per_request + ` + and MUST stay in sync. See :meth:`SlidingWindowManager.get_num_blocks_to_allocate` + for the recycling-vs-reservation safety argument; the same invariant + holds here via :meth:`remove_skipped_blocks`. + """ + capped_num_tokens = min( + num_tokens, self._max_admission_blocks_per_request * self.block_size + ) + return super().get_num_blocks_to_allocate( + request_id, + capped_num_tokens, + new_computed_blocks, + total_computed_tokens, + num_tokens_main_model, + ) @classmethod def find_longest_cache_hit( @@ -1126,8 +1221,20 @@ def __init__( def get_manager_for_kv_cache_spec( - kv_cache_spec: KVCacheSpec, **kwargs + kv_cache_spec: KVCacheSpec, + max_num_batched_tokens: int, + max_model_len: int, + **kwargs, ) -> SingleTypeKVCacheManager: manager_class = spec_manager_map[type(kv_cache_spec)] + # SlidingWindow / ChunkedLocalAttention managers recycle blocks across + # chunks; the runtime admission cap must match the recycling-aware bound + # the startup pool sizer uses (single source of truth: the spec method). + if isinstance(kv_cache_spec, (SlidingWindowSpec, ChunkedLocalAttentionSpec)): + kwargs["max_admission_blocks_per_request"] = ( + kv_cache_spec.max_admission_blocks_per_request( + max_num_batched_tokens, max_model_len + ) + ) manager = manager_class(kv_cache_spec, **kwargs) return manager diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index bc8422d4f4b5..8b52861f4f36 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -338,19 +338,29 @@ def merge(cls, specs: list[Self]) -> Self: class ChunkedLocalAttentionSpec(AttentionSpec): attention_chunk_size: int - def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - max_model_len = vllm_config.model_config.max_model_len - max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens - - # During chunked prefill, we allocate KV cache for at most - # `self.attention_chunk_size` computed tokens plus the newly scheduled - # tokens. And we won't allocate KV cache for more than `max_model_len` - # tokens. + def max_admission_blocks_per_request( + self, max_num_batched_tokens: int, max_model_len: int + ) -> int: + """Per-request admission cap, in blocks. + + Matches the recycling-aware bound used to size the pool at startup + (see `max_memory_usage_bytes`). Used by the runtime admission gate so + that requests admitted by startup can also be admitted at runtime. + """ + # During chunked prefill, we hold KV for at most one chunk window. num_tokens = min( self.attention_chunk_size + max_num_batched_tokens, max_model_len ) + return cdiv(num_tokens, self.block_size) - return cdiv(num_tokens, self.block_size) * self.page_size_bytes + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + return ( + self.max_admission_blocks_per_request( + vllm_config.scheduler_config.max_num_batched_tokens, + vllm_config.model_config.max_model_len, + ) + * self.page_size_bytes + ) @dataclass(frozen=True, kw_only=True) @@ -371,26 +381,41 @@ def real_page_size_bytes(self) -> int: * get_dtype_size(self.dtype) ) - def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - assert vllm_config.parallel_config.decode_context_parallel_size == 1, ( - "DCP not support sliding window." - ) - max_model_len = vllm_config.model_config.max_model_len - max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens + def max_admission_blocks_per_request( + self, max_num_batched_tokens: int, max_model_len: int + ) -> int: + """Per-request admission cap, in blocks. + + Matches the recycling-aware bound used to size the pool at startup + (see `max_memory_usage_bytes`). Used by the runtime admission gate so + that requests admitted by startup can also be admitted at runtime. - # During chunked prefill, we allocate KV cache for the last - # `self.sliding_window-1` computed tokens plus the newly scheduled - # tokens. And we won't allocate KV cache for more than `max_model_len` - # tokens. + Safety: `SlidingWindowManager.remove_skipped_blocks` is invoked from + `allocate_slots` before each chunk's `get_num_blocks_to_allocate`, so + the per-request real-held block count plateaus at this bound. + """ + # During chunked prefill, we hold KV for the last `sliding_window-1` + # computed tokens plus the newly scheduled tokens, and never more + # than `max_model_len`. num_tokens = min( self.sliding_window - 1 + max_num_batched_tokens, max_model_len ) + # +1 because the sliding window may not start from the beginning of + # the block. E.g. block size 4 and num_token 4 needs two blocks + # [XXCD][EF] to store the 6-token window [CDEF]. + return cdiv(num_tokens, self.block_size) + 1 - # +1 here because the sliding window may not start from the beginning - # of the block. For example, if the block size is 4 and num_token - # is 4, we need two blocks [XXCD] [EF] to store the sliding - # window [CDEF] of 6 tokens. - return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + assert vllm_config.parallel_config.decode_context_parallel_size == 1, ( + "DCP not support sliding window." + ) + return ( + self.max_admission_blocks_per_request( + vllm_config.scheduler_config.max_num_batched_tokens, + vllm_config.model_config.max_model_len, + ) + * self.page_size_bytes + ) @dataclass(frozen=True) diff --git a/vllm/v1/simple_kv_offload/manager.py b/vllm/v1/simple_kv_offload/manager.py index 5eedc07f717e..846526e5bee4 100644 --- a/vllm/v1/simple_kv_offload/manager.py +++ b/vllm/v1/simple_kv_offload/manager.py @@ -110,6 +110,9 @@ def __init__( self.cpu_coordinator: KVCacheCoordinator = get_kv_cache_coordinator( kv_cache_config=self.cpu_kv_cache_config, max_model_len=vllm_config.model_config.max_model_len, + max_num_batched_tokens=( + vllm_config.scheduler_config.max_num_batched_tokens + ), use_eagle=False, enable_caching=True, enable_kv_cache_events=self.enable_kv_cache_events, From ff6ed91c18381a0cc9a2f833475557ebe0040b42 Mon Sep 17 00:00:00 2001 From: Dao Le Date: Sun, 26 Apr 2026 23:57:31 +0000 Subject: [PATCH 2/4] Update comments Signed-off-by: Dao Le --- vllm/v1/core/single_type_kv_cache_manager.py | 48 +++++++------------- vllm/v1/kv_cache_interface.py | 18 ++++---- 2 files changed, 25 insertions(+), 41 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 9c3cc14c65c7..3e6b156cabf8 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -488,11 +488,6 @@ def __init__( ) -> None: super().__init__(kv_cache_spec, **kwargs) self.sliding_window = kv_cache_spec.sliding_window - # Recycling-aware admission cap: matches the bound used to size the - # pool at startup in `SlidingWindowSpec.max_memory_usage_bytes`. Per - # request, `remove_skipped_blocks` keeps the real-held block count - # from exceeding this; the admission gate composes per-request caps - # so total reservations never exceed the pool. self._max_admission_blocks_per_request = max_admission_blocks_per_request def get_num_blocks_to_allocate( @@ -505,26 +500,23 @@ def get_num_blocks_to_allocate( ) -> int: """Return the admission *reservation* (not lifetime block touches). - For a fresh request, the base implementation would ask for + For a fresh, request the base implementation asks for ``cdiv(num_tokens, block_size)`` blocks because ``get_num_skipped_tokens(0) == 0``. That over-counts: chunked prefill - invokes :meth:`remove_skipped_blocks` between chunks, which swaps + invokes `remove_skipped_blocks` between chunks, which swaps out-of-window blocks for ``null_block`` and returns their slots to the pool, so per-request real-held blocks plateau at roughly one window's worth. - We therefore cap demand at the same recycling-aware bound that - :meth:`SlidingWindowSpec.max_memory_usage_bytes - ` - used to size the pool at startup. The two formulas MUST stay in sync - (drift re-introduces the deadlock from issue #39734 or, worse, - mid-prefill OOM); both derive from - :meth:`SlidingWindowSpec.max_admission_blocks_per_request - `. - - Safety: per-request peak ≤ reservation (guaranteed by - :meth:`remove_skipped_blocks`), and the admission gate ensures - ``sum(reservations) ≤ pool``, so ``sum( peak_real_held) ≤ pool``. + We cap demand at the same recycling-aware bound that + `SlidingWindowSpec.max_memory_usage_bytes` used to size the pool at + startup; both call `SlidingWindowSpec.max_admission_blocks_per_request`, + the single source of truth -- drift would re-introduce the deadlock + from issue #39734 or, worse, mid-prefill OOM. + + Safety: per-request peak <= reservation (held by + `remove_skipped_blocks`), and the admission gate ensures + ``sum(reservations) <= pool``, so ``sum(peak_real_held) <= pool``. """ capped_num_tokens = min( num_tokens, self._max_admission_blocks_per_request * self.block_size @@ -681,9 +673,6 @@ def __init__( ) -> None: super().__init__(kv_cache_spec, **kwargs) self.attention_chunk_size = kv_cache_spec.attention_chunk_size - # Recycling-aware admission cap, mirroring the startup pool sizer in - # `ChunkedLocalAttentionSpec.max_memory_usage_bytes`. See - # `SlidingWindowManager` for the safety argument. self._max_admission_blocks_per_request = max_admission_blocks_per_request def get_num_blocks_to_allocate( @@ -696,15 +685,12 @@ def get_num_blocks_to_allocate( ) -> int: """Return the admission *reservation* (not lifetime block touches). - Caps demand at the recycling-aware bound mirroring - :meth:`ChunkedLocalAttentionSpec.max_memory_usage_bytes - `, - which the startup pool sizer uses. Both formulas derive from - :meth:`ChunkedLocalAttentionSpec.max_admission_blocks_per_request - ` - and MUST stay in sync. See :meth:`SlidingWindowManager.get_num_blocks_to_allocate` - for the recycling-vs-reservation safety argument; the same invariant - holds here via :meth:`remove_skipped_blocks`. + Caps demand at the recycling-aware bound from + `ChunkedLocalAttentionSpec.max_admission_blocks_per_request`, which + the startup pool sizer (`max_memory_usage_bytes`) also uses. See + `SlidingWindowManager.get_num_blocks_to_allocate` for the + recycling-vs-reservation safety argument; the same invariant holds + here via `remove_skipped_blocks`. """ capped_num_tokens = min( num_tokens, self._max_admission_blocks_per_request * self.block_size diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 8b52861f4f36..018570367b72 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -343,9 +343,9 @@ def max_admission_blocks_per_request( ) -> int: """Per-request admission cap, in blocks. - Matches the recycling-aware bound used to size the pool at startup - (see `max_memory_usage_bytes`). Used by the runtime admission gate so - that requests admitted by startup can also be admitted at runtime. + Single source of truth for both startup pool sizing + (`max_memory_usage_bytes`) and the runtime admission gate, so requests + admitted by startup can also be admitted at runtime. """ # During chunked prefill, we hold KV for at most one chunk window. num_tokens = min( @@ -386,13 +386,11 @@ def max_admission_blocks_per_request( ) -> int: """Per-request admission cap, in blocks. - Matches the recycling-aware bound used to size the pool at startup - (see `max_memory_usage_bytes`). Used by the runtime admission gate so - that requests admitted by startup can also be admitted at runtime. - - Safety: `SlidingWindowManager.remove_skipped_blocks` is invoked from - `allocate_slots` before each chunk's `get_num_blocks_to_allocate`, so - the per-request real-held block count plateaus at this bound. + Single source of truth for both startup pool sizing + (`max_memory_usage_bytes`) and the runtime admission gate. Per-request + real-held blocks plateau at this bound because + `SlidingWindowManager.remove_skipped_blocks` runs from `allocate_slots` + before each chunk's `get_num_blocks_to_allocate`. """ # During chunked prefill, we hold KV for the last `sliding_window-1` # computed tokens plus the newly scheduled tokens, and never more From 86271c0a0c85d601c405d4e7863bddb4ba4221fd Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sun, 26 Apr 2026 18:19:07 -0700 Subject: [PATCH 3/4] simplify Signed-off-by: Nick Hill --- vllm/v1/core/single_type_kv_cache_manager.py | 106 ++++--------------- vllm/v1/kv_cache_interface.py | 22 ++-- 2 files changed, 33 insertions(+), 95 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 3e6b156cabf8..0e3c8cbbd96e 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -40,6 +40,7 @@ def __init__( kv_cache_group_id: int, dcp_world_size: int = 1, pcp_world_size: int = 1, + max_admission_blocks_per_request: int | None = None, ) -> None: """ Initializes the SingleTypeKVCacheManager. @@ -47,6 +48,12 @@ def __init__( kv_cache_spec: The kv_cache_spec for this manager. block_pool: The block pool. kv_cache_group_id: The id of the kv cache group of this manager. + max_admission_blocks_per_request: Recycling-aware per-request + block cap used by `get_num_blocks_to_allocate`. Only set for + spec types that recycle blocks across chunks (SWA, + chunked-local); `None` (the default) means no cap, which is + correct for full-attention-style specs that hold every + block until the request finishes. """ self.block_size = kv_cache_spec.block_size self.dcp_world_size = dcp_world_size @@ -56,6 +63,7 @@ def __init__( self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool self.enable_caching = enable_caching + self._max_admission_blocks_per_request = max_admission_blocks_per_request self.new_block_ids: list[int] = [] # Mapping from request ID to blocks to track the blocks allocated @@ -104,6 +112,19 @@ def get_num_blocks_to_allocate( """ num_required_blocks = cdiv(num_tokens, self.block_size) + if self._max_admission_blocks_per_request is not None: + # Recycling-aware specs (SWA, chunked-local) cap the per-request + # reservation here so admission matches the startup pool sizer + # (`SlidingWindowSpec.max_admission_blocks_per_request` / its + # chunked-local counterpart). `remove_skipped_blocks` runs from + # `allocate_slots` before each chunk's `get_num_blocks_to_allocate`, + # so per-request peak real-held blocks <= this cap, which keeps + # `sum(reservations) <= pool` <=> `sum(peak_real_held) <= pool`. + # Drift between the two would re-introduce the deadlock from + # issue #39734 or, worse, mid-prefill OOM. + num_required_blocks = min( + num_required_blocks, self._max_admission_blocks_per_request + ) num_req_blocks = len(self.req_to_blocks.get(request_id, ())) if request_id in self.num_cached_block: @@ -479,55 +500,9 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: class SlidingWindowManager(SingleTypeKVCacheManager): - def __init__( - self, - kv_cache_spec: SlidingWindowSpec, - *, - max_admission_blocks_per_request: int, - **kwargs, - ) -> None: + def __init__(self, kv_cache_spec: SlidingWindowSpec, **kwargs) -> None: super().__init__(kv_cache_spec, **kwargs) self.sliding_window = kv_cache_spec.sliding_window - self._max_admission_blocks_per_request = max_admission_blocks_per_request - - def get_num_blocks_to_allocate( - self, - request_id: str, - num_tokens: int, - new_computed_blocks: Sequence[KVCacheBlock], - total_computed_tokens: int, - num_tokens_main_model: int, - ) -> int: - """Return the admission *reservation* (not lifetime block touches). - - For a fresh, request the base implementation asks for - ``cdiv(num_tokens, block_size)`` blocks because - ``get_num_skipped_tokens(0) == 0``. That over-counts: chunked prefill - invokes `remove_skipped_blocks` between chunks, which swaps - out-of-window blocks for ``null_block`` and returns their slots to - the pool, so per-request real-held blocks plateau at roughly one - window's worth. - - We cap demand at the same recycling-aware bound that - `SlidingWindowSpec.max_memory_usage_bytes` used to size the pool at - startup; both call `SlidingWindowSpec.max_admission_blocks_per_request`, - the single source of truth -- drift would re-introduce the deadlock - from issue #39734 or, worse, mid-prefill OOM. - - Safety: per-request peak <= reservation (held by - `remove_skipped_blocks`), and the admission gate ensures - ``sum(reservations) <= pool``, so ``sum(peak_real_held) <= pool``. - """ - capped_num_tokens = min( - num_tokens, self._max_admission_blocks_per_request * self.block_size - ) - return super().get_num_blocks_to_allocate( - request_id, - capped_num_tokens, - new_computed_blocks, - total_computed_tokens, - num_tokens_main_model, - ) @classmethod def find_longest_cache_hit( @@ -664,44 +639,9 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): - def __init__( - self, - kv_cache_spec: ChunkedLocalAttentionSpec, - *, - max_admission_blocks_per_request: int, - **kwargs, - ) -> None: + def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, **kwargs) -> None: super().__init__(kv_cache_spec, **kwargs) self.attention_chunk_size = kv_cache_spec.attention_chunk_size - self._max_admission_blocks_per_request = max_admission_blocks_per_request - - def get_num_blocks_to_allocate( - self, - request_id: str, - num_tokens: int, - new_computed_blocks: Sequence[KVCacheBlock], - total_computed_tokens: int, - num_tokens_main_model: int, - ) -> int: - """Return the admission *reservation* (not lifetime block touches). - - Caps demand at the recycling-aware bound from - `ChunkedLocalAttentionSpec.max_admission_blocks_per_request`, which - the startup pool sizer (`max_memory_usage_bytes`) also uses. See - `SlidingWindowManager.get_num_blocks_to_allocate` for the - recycling-vs-reservation safety argument; the same invariant holds - here via `remove_skipped_blocks`. - """ - capped_num_tokens = min( - num_tokens, self._max_admission_blocks_per_request * self.block_size - ) - return super().get_num_blocks_to_allocate( - request_id, - capped_num_tokens, - new_computed_blocks, - total_computed_tokens, - num_tokens_main_model, - ) @classmethod def find_longest_cache_hit( diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 018570367b72..17284411d98d 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -354,13 +354,12 @@ def max_admission_blocks_per_request( return cdiv(num_tokens, self.block_size) def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - return ( - self.max_admission_blocks_per_request( - vllm_config.scheduler_config.max_num_batched_tokens, - vllm_config.model_config.max_model_len, - ) - * self.page_size_bytes + max_model_len = vllm_config.model_config.max_model_len + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens + max_blocks = self.max_admission_blocks_per_request( + max_model_len, max_num_batched_tokens ) + return max_blocks * self.page_size_bytes @dataclass(frozen=True, kw_only=True) @@ -407,13 +406,12 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: assert vllm_config.parallel_config.decode_context_parallel_size == 1, ( "DCP not support sliding window." ) - return ( - self.max_admission_blocks_per_request( - vllm_config.scheduler_config.max_num_batched_tokens, - vllm_config.model_config.max_model_len, - ) - * self.page_size_bytes + max_model_len = vllm_config.model_config.max_model_len + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens + max_blocks = self.max_admission_blocks_per_request( + max_model_len, max_num_batched_tokens ) + return max_blocks * self.page_size_bytes @dataclass(frozen=True) From 66876ce34634fece4d1c74dacd2db3e32d351cd5 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sun, 26 Apr 2026 19:55:04 -0700 Subject: [PATCH 4/4] fix Signed-off-by: Nick Hill --- vllm/v1/core/single_type_kv_cache_manager.py | 3 ++- vllm/v1/kv_cache_interface.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 0e3c8cbbd96e..6700dbbccf55 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -1159,7 +1159,8 @@ def get_manager_for_kv_cache_spec( if isinstance(kv_cache_spec, (SlidingWindowSpec, ChunkedLocalAttentionSpec)): kwargs["max_admission_blocks_per_request"] = ( kv_cache_spec.max_admission_blocks_per_request( - max_num_batched_tokens, max_model_len + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_model_len, ) ) manager = manager_class(kv_cache_spec, **kwargs) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 17284411d98d..6bcaabf1c043 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -357,7 +357,7 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens max_blocks = self.max_admission_blocks_per_request( - max_model_len, max_num_batched_tokens + max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_model_len ) return max_blocks * self.page_size_bytes @@ -409,7 +409,7 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens max_blocks = self.max_admission_blocks_per_request( - max_model_len, max_num_batched_tokens + max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_model_len ) return max_blocks * self.page_size_bytes