From 79485997179bfce2d4bc1dfda662f34067b32cf5 Mon Sep 17 00:00:00 2001 From: Jhao-Ting Chen Date: Tue, 14 Apr 2026 20:42:26 -0700 Subject: [PATCH 1/2] [Scheduler] Cap SWA admission budget at sliding_window + chunk_size For hybrid SWA+full-attention models (e.g., Gemma 4), the can_fit_full_sequence admission gate passes full_num_tokens to get_num_blocks_to_allocate for all layer groups, including sliding window groups. Since total_computed_tokens is 0 for new requests, get_num_skipped_tokens returns 0, causing SWA groups to budget ceil(full_num_tokens / block_size) blocks instead of the window-sized amount they actually need. This over-budgeting throttles concurrent request admission. On Gemma 4 31B with 50 SWA layers (window=1024) and max_num_batched_tokens=8192, each SWA group budgets 1001 blocks instead of 576, causing 4 concurrent 65K-context sessions to be serialized through the gate. Fix: In KVCacheCoordinator.get_num_blocks_to_allocate, cap effective_num_tokens for SlidingWindowManager groups at sliding_window + max_num_batched_tokens. The window term is the steady-state max blocks, and the chunk term accounts for blocks needed during a single prefill chunk before remove_skipped_blocks frees OOW blocks. This matches TensorRT-LLM's getNeededBlocksOneStep. Plumbing: max_num_batched_tokens flows from SchedulerConfig through KVCacheManager and get_kv_cache_coordinator to all coordinator subclasses. 
Signed-off-by: Jhao-Ting Chen Co-Authored-By: Claude Opus 4.6 (1M context) --- vllm/v1/core/kv_cache_coordinator.py | 25 ++++++++++++++++++++++++- vllm/v1/core/kv_cache_manager.py | 3 +++ vllm/v1/core/sched/scheduler.py | 1 + 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index eaa95dfe49f7..ae016455fc95 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -15,6 +15,7 @@ from vllm.v1.core.single_type_kv_cache_manager import ( CrossAttentionManager, SingleTypeKVCacheManager, + SlidingWindowManager, get_manager_for_kv_cache_spec, ) from vllm.v1.kv_cache_interface import ( @@ -40,10 +41,12 @@ def __init__( dcp_world_size: int, pcp_world_size: int, hash_block_size: int, + max_num_batched_tokens: int = 0, metrics_collector: KVCacheMetricsCollector | None = None, ): self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len + self.max_num_batched_tokens = max_num_batched_tokens self.enable_caching = enable_caching self.block_pool = BlockPool( @@ -105,9 +108,19 @@ def get_num_blocks_to_allocate( request_id, num_encoder_tokens, [], 0, num_encoder_tokens ) else: + # Cap num_tokens for SWA: a sliding window layer never + # holds more than sliding_window blocks at steady state. + # OOW blocks are freed between chunks by + # remove_skipped_blocks(), so the admission check only + # needs to budget for the window, not the full sequence. 
+ effective_num_tokens = num_tokens + if isinstance(manager, SlidingWindowManager): + effective_num_tokens = min( + num_tokens, manager.sliding_window + self.max_num_batched_tokens + ) num_blocks_to_allocate += manager.get_num_blocks_to_allocate( request_id, - num_tokens, + effective_num_tokens, new_computed_blocks[i], total_computed_tokens, num_tokens_main_model, @@ -270,6 +283,7 @@ def __init__( dcp_world_size: int, pcp_world_size: int, hash_block_size: int, + max_num_batched_tokens: int = 0, metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( @@ -281,6 +295,7 @@ def __init__( dcp_world_size=dcp_world_size, pcp_world_size=pcp_world_size, hash_block_size=hash_block_size, + max_num_batched_tokens=max_num_batched_tokens, metrics_collector=metrics_collector, ) self.num_single_type_manager = len(self.single_type_managers) @@ -316,6 +331,7 @@ def __init__( dcp_world_size: int, pcp_world_size: int, hash_block_size: int, + max_num_batched_tokens: int = 0, metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( @@ -327,6 +343,7 @@ def __init__( dcp_world_size=dcp_world_size, pcp_world_size=pcp_world_size, hash_block_size=hash_block_size, + max_num_batched_tokens=max_num_batched_tokens, metrics_collector=metrics_collector, ) self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec @@ -381,6 +398,7 @@ def __init__( dcp_world_size: int, pcp_world_size: int, hash_block_size: int, + max_num_batched_tokens: int = 0, metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( @@ -392,6 +410,7 @@ def __init__( dcp_world_size=dcp_world_size, pcp_world_size=pcp_world_size, hash_block_size=hash_block_size, + max_num_batched_tokens=max_num_batched_tokens, metrics_collector=metrics_collector, ) # hash_block_size: the block size used to compute block hashes. 
@@ -547,6 +566,7 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: def get_kv_cache_coordinator( kv_cache_config: KVCacheConfig, max_model_len: int, + max_num_batched_tokens: int, use_eagle: bool, enable_caching: bool, enable_kv_cache_events: bool, @@ -564,6 +584,7 @@ def get_kv_cache_coordinator( dcp_world_size=dcp_world_size, pcp_world_size=pcp_world_size, hash_block_size=hash_block_size, + max_num_batched_tokens=max_num_batched_tokens, metrics_collector=metrics_collector, ) if len(kv_cache_config.kv_cache_groups) == 1: @@ -576,6 +597,7 @@ def get_kv_cache_coordinator( dcp_world_size=dcp_world_size, pcp_world_size=pcp_world_size, hash_block_size=hash_block_size, + max_num_batched_tokens=max_num_batched_tokens, metrics_collector=metrics_collector, ) return HybridKVCacheCoordinator( @@ -587,5 +609,6 @@ def get_kv_cache_coordinator( dcp_world_size=dcp_world_size, pcp_world_size=pcp_world_size, hash_block_size=hash_block_size, + max_num_batched_tokens=max_num_batched_tokens, metrics_collector=metrics_collector, ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index dcec5e05bf97..3830f9f566c4 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -109,6 +109,7 @@ def __init__( kv_cache_config: KVCacheConfig, max_model_len: int, hash_block_size: int, + max_num_batched_tokens: int = 0, enable_caching: bool = True, use_eagle: bool = False, log_stats: bool = False, @@ -118,6 +119,7 @@ def __init__( metrics_collector: KVCacheMetricsCollector | None = None, ) -> None: self.max_model_len = max_model_len + self.max_num_batched_tokens = max_num_batched_tokens self.enable_caching = enable_caching self.use_eagle = use_eagle @@ -131,6 +133,7 @@ def __init__( self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, + max_num_batched_tokens=self.max_num_batched_tokens, use_eagle=self.use_eagle, enable_caching=self.enable_caching, 
enable_kv_cache_events=enable_kv_cache_events, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f6e96c677485..9479c8dd6889 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -225,6 +225,7 @@ def __init__( self.kv_cache_manager = KVCacheManager( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, + max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens, enable_caching=self.cache_config.enable_prefix_caching, use_eagle=self.use_eagle, log_stats=self.log_stats, From bbfb686fbf216313729bc23a47e0463704774c04 Mon Sep 17 00:00:00 2001 From: Josh Ferguson Date: Wed, 15 Apr 2026 23:10:36 -0500 Subject: [PATCH 2/2] Guard uncapped SWA allocation after capped admission --- tests/v1/core/test_prefix_caching.py | 53 ++++++++++++++++++++++++ vllm/v1/core/kv_cache_coordinator.py | 60 ++++++++++++++++++++++------ vllm/v1/core/kv_cache_manager.py | 18 ++++++++- 3 files changed, 117 insertions(+), 14 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 22220599f158..c159d52fad68 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1831,6 +1831,59 @@ def test_reset_prefix_cache(): assert all([blk.block_hash is None for blk in manager.block_pool.blocks]) +def test_hybrid_swa_cap_does_not_crash_allocator(): + block_size = 16 + sliding_window = 64 + num_tokens = 200 + request_id = "r" + + manager = KVCacheManager( + kv_cache_config=KVCacheConfig( + num_blocks=26, # 25 usable blocks + 1 null block. 
+ kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["full"], + FullAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ), + ), + KVCacheGroupSpec( + ["swa"], + SlidingWindowSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=sliding_window, + ), + ), + ], + ), + max_model_len=4096, + hash_block_size=block_size, + max_num_batched_tokens=32, + enable_caching=True, + ) + + req = make_request(request_id, [1] * num_tokens, block_size, sha256) + full_required_blocks = (num_tokens + block_size - 1) // block_size + allocated = manager.block_pool.get_new_blocks( + manager.block_pool.num_gpu_blocks - 1) + full_mgr, swa_mgr = manager.coordinator.single_type_managers + + full_mgr.req_to_blocks[request_id] = allocated[:full_required_blocks] + swa_mgr.req_to_blocks[request_id] = allocated[ + full_required_blocks:full_required_blocks + (full_required_blocks - 1) + ] + + assert manager.block_pool.get_num_free_blocks() == 0 + assert manager.allocate_slots(req, num_new_tokens=num_tokens) is None + + def test_prefix_cache_stats_disabled(): """Test that prefix_cache_stats is None when log_stats is False.""" block_size = 16 diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index ae016455fc95..027564e1f219 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -71,7 +71,7 @@ def __init__( for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups) ) - def get_num_blocks_to_allocate( + def get_num_blocks_needed_for_admission( self, request_id: str, num_tokens: int, @@ -81,7 +81,12 @@ def get_num_blocks_to_allocate( num_tokens_main_model: int, ) -> int: """ - Get the number of blocks needed to be allocated for the request. + Get the number of blocks needed for admission. + + This is an admission estimate, not the exact allocator demand. 
For SWA + groups, cap the token count at sliding_window + one prefill chunk. + OOW blocks are freed between chunks by remove_skipped_blocks(), so the + scheduler does not need to budget the full sequence length here. Args: request_id: The request ID. @@ -108,16 +113,8 @@ def get_num_blocks_to_allocate( request_id, num_encoder_tokens, [], 0, num_encoder_tokens ) else: - # Cap num_tokens for SWA: a sliding window layer never - # holds more than sliding_window blocks at steady state. - # OOW blocks are freed between chunks by - # remove_skipped_blocks(), so the admission check only - # needs to budget for the window, not the full sequence. - effective_num_tokens = num_tokens - if isinstance(manager, SlidingWindowManager): - effective_num_tokens = min( - num_tokens, manager.sliding_window + self.max_num_batched_tokens - ) + effective_num_tokens = self._get_admission_num_tokens( + manager, num_tokens) num_blocks_to_allocate += manager.get_num_blocks_to_allocate( request_id, effective_num_tokens, @@ -127,6 +124,34 @@ def get_num_blocks_to_allocate( ) return num_blocks_to_allocate + def get_num_blocks_to_allocate( + self, + request_id: str, + num_tokens: int, + new_computed_blocks: tuple[Sequence[KVCacheBlock], ...], + num_encoder_tokens: int, + total_computed_tokens: int, + num_tokens_main_model: int, + ) -> int: + """ + Get the exact number of blocks needed to allocate for the request. 
+ """ + num_blocks_to_allocate = 0 + for i, manager in enumerate(self.single_type_managers): + if isinstance(manager, CrossAttentionManager): + num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + request_id, num_encoder_tokens, [], 0, num_encoder_tokens + ) + else: + num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + request_id, + num_tokens, + new_computed_blocks[i], + total_computed_tokens, + num_tokens_main_model, + ) + return num_blocks_to_allocate + def allocate_new_computed_blocks( self, request_id: str, @@ -188,6 +213,17 @@ def allocate_new_blocks( for manager in self.single_type_managers ) + def _get_admission_num_tokens( + self, + manager: SingleTypeKVCacheManager, + num_tokens: int, + ) -> int: + # SWA layers never need more than their window plus one prefill chunk. + if isinstance(manager, SlidingWindowManager): + return min(num_tokens, + manager.sliding_window + self.max_num_batched_tokens) + return num_tokens + def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: """ Cache the blocks for the request. 
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 3830f9f566c4..0b3d78bcb96a 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -377,7 +377,7 @@ def allocate_slots( request.request_id, total_computed_tokens ) - num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( + block_allocation_kwargs = dict( request_id=request.request_id, num_tokens=num_tokens_need_slot, new_computed_blocks=new_computed_block_list, @@ -387,10 +387,24 @@ def allocate_slots( num_tokens_main_model=num_tokens_main_model, ) - if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): + num_blocks_to_allocate = ( + self.coordinator.get_num_blocks_needed_for_admission( + **block_allocation_kwargs)) + + num_free_blocks = self.block_pool.get_num_free_blocks() + if num_blocks_to_allocate > num_free_blocks: # Cannot allocate new blocks return None + # The SWA-capped admission estimate can be lower than the actual + # allocator demand when a request already owns some blocks. Re-check + # with the uncapped token count so we fail cleanly instead of throwing + # out of BlockPool.get_new_blocks(). + actual_num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( + **block_allocation_kwargs) + if actual_num_blocks_to_allocate > num_free_blocks: + return None + if ( new_computed_block_list is not self.empty_kv_cache_blocks.blocks or num_external_computed_tokens > 0