From ad830a26265a8256556a6b22d03520d331913b1d Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Fri, 20 Mar 2026 09:01:34 -0400 Subject: [PATCH 01/14] Marconi admission policy for hybrid cache. Signed-off-by: Stanislaw Wozniak --- vllm/v1/core/kv_cache_coordinator.py | 10 ++++++-- vllm/v1/core/sched/scheduler.py | 34 +++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index eaa95dfe49f7..081992a6785d 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -482,6 +482,7 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: num_groups = len(self.kv_cache_config.kv_cache_groups) hit_length = max_cache_hit_length + longest_hit_length = 0 hit_blocks_by_group: list[list[KVCacheBlock] | None] = [None] * num_groups # Simple hybrid (1 full attn + 1 other): one iteration suffices. @@ -523,7 +524,12 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: curr_hit_length = len(hit_blocks[0]) * spec.block_size for group_id, blocks in zip(group_ids, hit_blocks): hit_blocks_by_group[group_id] = blocks - + + # Collect information on the longest cached prefix overall + # (no matter the attention type) to allow for more complex + # caching policies + longest_hit_length = max(longest_hit_length, curr_hit_length) + if curr_hit_length >= hit_length: break hit_length = curr_hit_length @@ -541,7 +547,7 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: return tuple( blocks if blocks is not None else [] for blocks in hit_blocks_by_group - ), hit_length + ), longest_hit_length # longest_hit_length >= hit_length def get_kv_cache_coordinator( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 486ce8debc88..abc4c8a2aff2 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -291,6 +291,7 @@ def _mamba_block_aligned_split( num_new_tokens: int, num_new_local_computed_tokens: int = 0, num_external_computed_tokens: int = 0, + mamba_tokens_lag: int = 0, ) -> int: assert num_external_computed_tokens == 0, ( "External KV connector is not verified yet" @@ -333,6 +334,21 @@ def _mamba_block_aligned_split( else: # prefill the last few tokens pass + + # Marconi cache admission optimization: + # Create cache entries at divergence points of common prefixes. + # + # Implementation: + # If mamba cache "lags" behind the KVCache hits for >= block_size, + # there is a common shared prefix that wasn't cached. + if mamba_tokens_lag >= block_size: + # If num_new_tokens is longer than lag, + # the prefix normally still wouldn't be cached + if num_new_tokens > mamba_tokens_lag: + # So we force caching at mamba_tokens_lag + num_new_tokens = mamba_tokens_lag + assert mamba_tokens_lag % block_size == 0 #TODO? + #num_new_tokens = num_new_tokens // block_size * block_size return num_new_tokens def schedule(self) -> SchedulerOutput: @@ -602,6 +618,21 @@ def schedule(self) -> SchedulerOutput: new_computed_blocks, num_new_local_computed_tokens = ( self.kv_cache_manager.get_computed_blocks(request) ) + + # More proper check would be: + # if isinstance(self.kv_cache_manager.coordinator, + # HybridKVCacheCoordinator): + # but this check is similar and avoids + # importing HybridKVCacheCoordinator: + if self.has_mamba_layers: + # HybridKVCacheCoordinator returns the longest hit: + longest_hit_length = num_new_local_computed_tokens + # Obtain the shortest cached prefix from the blocks: + num_new_local_computed_tokens = \ + len(new_computed_blocks.blocks[0]) * self.block_size + # Mamba tokens "lag" - how far it's behind longest hit: + mamba_tokens_lag = \ + longest_hit_length - num_new_local_computed_tokens # Get externally-cached tokens if using a KVConnector. if self.connector is not None: @@ -689,11 +720,12 @@ def schedule(self) -> SchedulerOutput: break if self.need_mamba_block_aligned_split: - num_new_tokens = self._mamba_block_aligned_split( + num_new_tokens = self._mamba_block_aligned_split( # TODO: HERE SPLIT should know the FA length and break there first request, num_new_tokens, num_new_local_computed_tokens, num_external_computed_tokens, + mamba_tokens_lag, ) if num_new_tokens == 0: break From d250f8d5d824f11c9461d009e3d9d60cea0f40f5 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Mon, 23 Mar 2026 11:04:36 -0400 Subject: [PATCH 02/14] Cleanup. Signed-off-by: Stanislaw Wozniak --- vllm/v1/core/sched/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index abc4c8a2aff2..dd153cb69495 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -720,7 +720,7 @@ def schedule(self) -> SchedulerOutput: break if self.need_mamba_block_aligned_split: - num_new_tokens = self._mamba_block_aligned_split( # TODO: HERE SPLIT should know the FA length and break there first + num_new_tokens = self._mamba_block_aligned_split( request, num_new_tokens, num_new_local_computed_tokens, From 54eedc20be3f67ece23847e5c4aa17c377f4f020 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Thu, 26 Mar 2026 13:51:35 -0400 Subject: [PATCH 03/14] Feedback from Gemini Signed-off-by: Stanislaw Wozniak --- vllm/v1/core/sched/scheduler.py | 40 ++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index dd153cb69495..60afc9447516 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -338,17 +338,24 @@ def _mamba_block_aligned_split( # Marconi cache admission optimization: # Create cache entries at divergence points of common prefixes. # - # Implementation: - # If mamba cache "lags" behind the KVCache hits for >= block_size, - # there is a common shared prefix that wasn't cached. - if mamba_tokens_lag >= block_size: - # If num_new_tokens is longer than lag, - # the prefix normally still wouldn't be cached - if num_new_tokens > mamba_tokens_lag: - # So we force caching at mamba_tokens_lag - num_new_tokens = mamba_tokens_lag - assert mamba_tokens_lag % block_size == 0 #TODO? - #num_new_tokens = num_new_tokens // block_size * block_size + # Implementation: + # If uncached common prefix (mamba_tokens_lag) is long enough + # to justify its caching ( >= block_size) + # AND + # currently scheduled token count is longer than the common prefix + if mamba_tokens_lag >= block_size and \ + num_new_tokens > mamba_tokens_lag: + # Then force to cache at the end of the common prefix + # by limiting the num_new_tokens to the length of that prefix: + num_new_tokens = mamba_tokens_lag + # This should be still block aligned as: + # - token hit counts are block aligned + # - thus mamba_tokens_lag is block aligned + # - attention and mamba block sizes are equal + # Optionally, we can verify this: + assert mamba_tokens_lag % block_size == 0 + # Or force block re-alignment: + # num_new_tokens = num_new_tokens // block_size * block_size return num_new_tokens def schedule(self) -> SchedulerOutput: @@ -627,12 +634,15 @@ def schedule(self) -> SchedulerOutput: if self.has_mamba_layers: # HybridKVCacheCoordinator returns the longest hit: longest_hit_length = num_new_local_computed_tokens - # Obtain the shortest cached prefix from the blocks: - num_new_local_computed_tokens = \ + # HybridKVCacheCoordinator returns the blocks of + # the common hit, from which we obtain the hit length: + common_hit_length = \ len(new_computed_blocks.blocks[0]) * self.block_size - # Mamba tokens "lag" - how far it's behind longest hit: + # How many tokens mamba cache is behind the longest hit: mamba_tokens_lag = \ - longest_hit_length - num_new_local_computed_tokens + longest_hit_length - common_hit_length + # Resume default scheduler logic based on the common hit + num_new_local_computed_tokens = common_hit_length # Get externally-cached tokens if using a KVConnector. if self.connector is not None: From 72809a4853176408653a2b94a79d3e63bfab1fbd Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Thu, 26 Mar 2026 13:53:48 -0400 Subject: [PATCH 04/14] Pre-commit fixes Signed-off-by: Stanislaw Wozniak --- vllm/v1/core/kv_cache_coordinator.py | 6 +++--- vllm/v1/core/sched/scheduler.py | 23 +++++++++++------------ 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 081992a6785d..03958436920d 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -524,12 +524,12 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: curr_hit_length = len(hit_blocks[0]) * spec.block_size for group_id, blocks in zip(group_ids, hit_blocks): hit_blocks_by_group[group_id] = blocks - + # Collect information on the longest cached prefix overall # (no matter the attention type) to allow for more complex # caching policies longest_hit_length = max(longest_hit_length, curr_hit_length) - + if curr_hit_length >= hit_length: break hit_length = curr_hit_length @@ -547,7 +547,7 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: return tuple( blocks if blocks is not None else [] for blocks in hit_blocks_by_group - ), longest_hit_length # longest_hit_length >= hit_length + ), longest_hit_length # longest_hit_length >= hit_length def get_kv_cache_coordinator( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 60afc9447516..4f8d63e490e2 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -334,17 +334,16 @@ def _mamba_block_aligned_split( else: # prefill the last few tokens pass - + # Marconi cache admission optimization: # Create cache entries at divergence points of common prefixes. - # - # Implementation: - # If uncached common prefix (mamba_tokens_lag) is long enough - # to justify its caching ( >= block_size) + # + # Implementation: + # If uncached common prefix (mamba_tokens_lag) is long enough + # to justify its caching ( >= block_size) # AND # currently scheduled token count is longer than the common prefix - if mamba_tokens_lag >= block_size and \ - num_new_tokens > mamba_tokens_lag: + if mamba_tokens_lag >= block_size and num_new_tokens > mamba_tokens_lag: # Then force to cache at the end of the common prefix # by limiting the num_new_tokens to the length of that prefix: num_new_tokens = mamba_tokens_lag @@ -625,9 +624,9 @@ def schedule(self) -> SchedulerOutput: new_computed_blocks, num_new_local_computed_tokens = ( self.kv_cache_manager.get_computed_blocks(request) ) - + # More proper check would be: - # if isinstance(self.kv_cache_manager.coordinator, + # if isinstance(self.kv_cache_manager.coordinator, # HybridKVCacheCoordinator): # but this check is similar and avoids # importing HybridKVCacheCoordinator: @@ -636,11 +635,11 @@ def schedule(self) -> SchedulerOutput: longest_hit_length = num_new_local_computed_tokens # HybridKVCacheCoordinator returns the blocks of # the common hit, from which we obtain the hit length: - common_hit_length = \ + common_hit_length = ( len(new_computed_blocks.blocks[0]) * self.block_size + ) # How many tokens mamba cache is behind the longest hit: - mamba_tokens_lag = \ - longest_hit_length - common_hit_length + mamba_tokens_lag = longest_hit_length - common_hit_length # Resume default scheduler logic based on the common hit num_new_local_computed_tokens = common_hit_length From e67adb2ec5ab171b370c3da390cdd6305a999f66 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Wed, 22 Apr 2026 09:58:08 -0400 Subject: [PATCH 05/14] Small naming changes Signed-off-by: Stanislaw Wozniak --- vllm/v1/core/sched/scheduler.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4f8d63e490e2..6f4a9eebda96 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -291,7 +291,7 @@ def _mamba_block_aligned_split( num_new_tokens: int, num_new_local_computed_tokens: int = 0, num_external_computed_tokens: int = 0, - mamba_tokens_lag: int = 0, + num_uncached_common_prefix_tokens: int = 0, ) -> int: assert num_external_computed_tokens == 0, ( "External KV connector is not verified yet" @@ -339,20 +339,23 @@ def _mamba_block_aligned_split( # Create cache entries at divergence points of common prefixes. # # Implementation: - # If uncached common prefix (mamba_tokens_lag) is long enough - # to justify its caching ( >= block_size) + # If uncached common prefix (num_uncached_common_prefix_tokens) + # is long enough to justify its caching ( >= block_size) # AND # currently scheduled token count is longer than the common prefix - if mamba_tokens_lag >= block_size and num_new_tokens > mamba_tokens_lag: + if ( + num_uncached_common_prefix_tokens >= block_size + and num_new_tokens > num_uncached_common_prefix_tokens + ): # Then force to cache at the end of the common prefix # by limiting the num_new_tokens to the length of that prefix: - num_new_tokens = mamba_tokens_lag + num_new_tokens = num_uncached_common_prefix_tokens # This should be still block aligned as: # - token hit counts are block aligned - # - thus mamba_tokens_lag is block aligned + # - thus num_uncached_common_prefix_tokens is block aligned # - attention and mamba block sizes are equal # Optionally, we can verify this: - assert mamba_tokens_lag % block_size == 0 + assert num_new_tokens % block_size == 0 # Or force block re-alignment: # num_new_tokens = num_new_tokens // block_size * block_size return num_new_tokens @@ -639,7 +642,9 @@ def schedule(self) -> SchedulerOutput: len(new_computed_blocks.blocks[0]) * self.block_size ) # How many tokens mamba cache is behind the longest hit: - mamba_tokens_lag = longest_hit_length - common_hit_length + num_uncached_common_prefix_tokens = ( + longest_hit_length - common_hit_length + ) # Resume default scheduler logic based on the common hit num_new_local_computed_tokens = common_hit_length @@ -734,7 +739,7 @@ def schedule(self) -> SchedulerOutput: num_new_tokens, num_new_local_computed_tokens, num_external_computed_tokens, - mamba_tokens_lag, + num_uncached_common_prefix_tokens, ) if num_new_tokens == 0: break From 8ae4013989f5d1a7a9dda2c99ea56c15b4c36ac6 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Wed, 22 Apr 2026 11:42:46 -0400 Subject: [PATCH 06/14] Cleaner version returning common prefix from cache coordinator Signed-off-by: Stanislaw Wozniak --- vllm/v1/core/kv_cache_coordinator.py | 27 ++++++++++++++++++--------- vllm/v1/core/kv_cache_manager.py | 15 +++++++++++---- vllm/v1/core/sched/scheduler.py | 28 +++++----------------------- 3 files changed, 34 insertions(+), 36 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 03958436920d..b4ef0643e716 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -244,7 +244,7 @@ def find_longest_cache_hit( self, block_hashes: list[BlockHash], max_cache_hit_length: int, - ) -> tuple[tuple[list[KVCacheBlock], ...], int]: + ) -> tuple[tuple[list[KVCacheBlock], ...], int, int]: pass def new_step_starts(self) -> None: @@ -292,11 +292,11 @@ def find_longest_cache_hit( self, block_hashes: list[BlockHash], max_cache_hit_length: int, - ) -> tuple[tuple[list[KVCacheBlock], ...], int]: + ) -> tuple[tuple[list[KVCacheBlock], ...], int, int]: blocks: tuple[list[KVCacheBlock], ...] = tuple( [] for _ in range(self.num_single_type_manager) ) - return blocks, 0 + return blocks, 0, 0 class UnitaryKVCacheCoordinator(KVCacheCoordinator): @@ -350,7 +350,7 @@ def find_longest_cache_hit( self, block_hashes: list[BlockHash], max_cache_hit_length: int, - ) -> tuple[tuple[list[KVCacheBlock], ...], int]: + ) -> tuple[tuple[list[KVCacheBlock], ...], int, int]: hit_blocks = self.single_type_managers[0].find_longest_cache_hit( block_hashes=block_hashes, max_length=max_cache_hit_length, @@ -362,7 +362,7 @@ def find_longest_cache_hit( dcp_world_size=self.dcp_world_size, pcp_world_size=self.pcp_world_size, ) - return hit_blocks, len(hit_blocks[0]) * self.block_size + return hit_blocks, len(hit_blocks[0]) * self.block_size, 0 class HybridKVCacheCoordinator(KVCacheCoordinator): @@ -454,7 +454,7 @@ def find_longest_cache_hit( self, block_hashes: list[BlockHash], max_cache_hit_length: int, - ) -> tuple[tuple[list[KVCacheBlock], ...], int]: + ) -> tuple[tuple[list[KVCacheBlock], ...], int, int]: """ Find the longest cache hit using an iterative fixed-point algorithm. @@ -545,9 +545,18 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: if (blks := hit_blocks_by_group[group_id]) is not None: del blks[num_blocks:] - return tuple( - blocks if blocks is not None else [] for blocks in hit_blocks_by_group - ), longest_hit_length # longest_hit_length >= hit_length + # Uncached shared prefix detection heuristic: + # If any attention group cached a longer prefix than the current common + # prefix, there was a request with a shared prefix of at least that + # length in the past. Return the length of such common prefix: + num_uncached_common_prefix_tokens = longest_hit_length - hit_length + return ( + tuple( + blocks if blocks is not None else [] for blocks in hit_blocks_by_group + ), + hit_length, + num_uncached_common_prefix_tokens, + ) def get_kv_cache_coordinator( diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 2c712a1b1838..3a5da1fbc39d 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -173,7 +173,7 @@ def make_prefix_cache_stats(self) -> PrefixCacheStats | None: self.prefix_cache_stats = PrefixCacheStats() return stats - def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: + def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int, int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -184,13 +184,16 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: A tuple containing: - A list of blocks that are computed for the request. - The number of computed tokens. + - 0, or the number of uncached shared prefix tokens + beyond computed tokens, as detected by shared prefix detection + (currently implemented in HybridKVCacheCoordinator) """ # We skip finding the prefix cache hit when prefix caching is # disabled or the request is marked as skipping kv cache read # (which happens when the request requires prompt logprobs # or calls a pooling model with all pooling). if not self.enable_caching or request.skip_reading_prefix_cache: - return self.empty_kv_cache_blocks, 0 + return self.empty_kv_cache_blocks, 0, 0 # NOTE: When all tokens hit the cache, we must recompute the last token # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1. @@ -199,7 +202,7 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: # num_computed_tokens to be block-size aligned. Removing this limitation # could slightly improve performance in the future. max_cache_hit_length = request.num_tokens - 1 - computed_blocks, num_new_computed_tokens = ( + computed_blocks, num_new_computed_tokens, num_uncached_common_prefix_tokens = ( self.coordinator.find_longest_cache_hit( request.block_hashes, max_cache_hit_length ) @@ -213,7 +216,11 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: preempted=request.num_preemptions > 0, ) - return self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens + return ( + self.create_kv_cache_blocks(computed_blocks), + num_new_computed_tokens, + num_uncached_common_prefix_tokens, + ) def allocate_slots( self, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 6f4a9eebda96..e6a7daab255d 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -624,29 +624,11 @@ def schedule(self) -> SchedulerOutput: # Get already-cached tokens. if request.num_computed_tokens == 0: # Get locally-cached tokens. - new_computed_blocks, num_new_local_computed_tokens = ( - self.kv_cache_manager.get_computed_blocks(request) - ) - - # More proper check would be: - # if isinstance(self.kv_cache_manager.coordinator, - # HybridKVCacheCoordinator): - # but this check is similar and avoids - # importing HybridKVCacheCoordinator: - if self.has_mamba_layers: - # HybridKVCacheCoordinator returns the longest hit: - longest_hit_length = num_new_local_computed_tokens - # HybridKVCacheCoordinator returns the blocks of - # the common hit, from which we obtain the hit length: - common_hit_length = ( - len(new_computed_blocks.blocks[0]) * self.block_size - ) - # How many tokens mamba cache is behind the longest hit: - num_uncached_common_prefix_tokens = ( - longest_hit_length - common_hit_length - ) - # Resume default scheduler logic based on the common hit - num_new_local_computed_tokens = common_hit_length + ( + new_computed_blocks, + num_new_local_computed_tokens, + num_uncached_common_prefix_tokens, + ) = self.kv_cache_manager.get_computed_blocks(request) # Get externally-cached tokens if using a KVConnector. if self.connector is not None: From 0e3af1d8004d65e46c8c7a810c872a17a80da8e0 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Thu, 23 Apr 2026 18:14:13 -0400 Subject: [PATCH 07/14] Pre-commit fixes Signed-off-by: Stanislaw Wozniak --- vllm/v1/simple_kv_offload/manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/simple_kv_offload/manager.py b/vllm/v1/simple_kv_offload/manager.py index 5eedc07f717e..861a946c35b6 100644 --- a/vllm/v1/simple_kv_offload/manager.py +++ b/vllm/v1/simple_kv_offload/manager.py @@ -219,7 +219,7 @@ def get_num_new_matched_tokens( max_hit_len = request.num_tokens - 1 - num_computed_tokens if max_hit_len <= 0: return 0, False - _, hit_length = self.cpu_coordinator.find_longest_cache_hit( + _, hit_length, _ = self.cpu_coordinator.find_longest_cache_hit( remaining_hashes, max_hit_len ) @@ -261,7 +261,7 @@ def update_state_after_alloc( # Find CPU cached blocks across all groups. max_hit_len = len(hashes_to_load) * self.block_size - cpu_hit_blocks, hit_length = self.cpu_coordinator.find_longest_cache_hit( + cpu_hit_blocks, hit_length, _ = self.cpu_coordinator.find_longest_cache_hit( hashes_to_load, max_hit_len ) assert hit_length == num_external_tokens, ( From 9cc01e2530d52370caf767bf790e329008ac6c5a Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Fri, 24 Apr 2026 04:57:06 -0400 Subject: [PATCH 08/14] Adapt tests to pass with new return structure Signed-off-by: Stanislaw Wozniak --- tests/v1/core/test_prefix_caching.py | 102 +++++++++++++-------------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 22220599f158..f0cdef4faf49 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -221,7 +221,7 @@ def test_prefill(hash_fn): unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -251,7 +251,7 @@ def test_prefill(hash_fn): # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3],) assert num_computed_tokens == 3 * 16 @@ -285,7 +285,7 @@ def test_prefill(hash_fn): # Incomplete 1 block (6 tokens) unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req2) assert len(req2.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3],) assert num_computed_tokens == 3 * 16 @@ -305,7 +305,7 @@ def test_prefill(hash_fn): # Cache miss and eviction. req3 = make_request("3", [99] * (16 * 10), block_size, hash_fn) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -347,7 +347,7 @@ def test_prefill_hybrid_model(): unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -383,7 +383,7 @@ def test_prefill_hybrid_model(): unique_token_ids = [3] * 5 all_token_ids = common_token_ids + unique_token_ids req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6, 7], [0, 10, 11]) assert num_computed_tokens == 3 * 16 @@ -517,7 +517,7 @@ def test_prefill_hybrid_model_eagle(): unique_token_ids = [6] * 7 all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) assert len(req0.block_hashes) == len(all_token_ids) // block_size assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -554,7 +554,7 @@ def test_prefill_hybrid_model_eagle(): unique_token_ids = [6] * 5 all_token_ids = common_token_ids + unique_token_ids req1 = make_request("1", all_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == num_full_blocks assert computed_blocks.get_block_ids() == ( [1, 2, 3, 4], @@ -691,7 +691,7 @@ def _test_partial_request_hit( req = make_request(request_id, prompt_token_ids, block_size, sha256) for hash_with_group_id in hash_to_evict: manager.block_pool.cached_block_hash_to_block._cache.pop(hash_with_group_id) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req) assert len(req.block_hashes) == num_full_blocks assert num_computed_tokens == expect_hit_length * block_size for block_per_group in computed_blocks.blocks: @@ -852,7 +852,7 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]): # First request: no cache hit initially req0 = make_request("0", all_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] # No cache hit initially @@ -869,7 +869,7 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]): # Second request: should hit cached blocks for common prefix req1 = make_request("1", common_token_ids + [4] * 5, block_size, hash_fn) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) # Should hit cached blocks for all groups assert num_computed_tokens == 3 * block_size @@ -929,7 +929,7 @@ def test_prefill_hybrid_model_combinations_eagle( # First request: no cache hit initially req0 = make_request("0", all_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) assert len(req0.block_hashes) == num_full_blocks assert not computed_blocks.blocks[0] # No cache hit initially @@ -945,7 +945,7 @@ def test_prefill_hybrid_model_combinations_eagle( # Second request: should hit cached blocks for common prefix all_token_ids = common_token_ids + [6] * 5 req1 = make_request("1", all_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) # Should hit cached blocks for all groups assert num_computed_tokens == expect_hit_length * block_size @@ -998,7 +998,7 @@ def test_prefill_hybrid_model_mamba_align(): # First request: allocate_slots should not crash with the assertion error # in MambaManager.cache_blocks() when null blocks are present. req0 = make_request("0", all_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, num_computed_tokens, computed_blocks) @@ -1034,7 +1034,7 @@ def test_prefill_plp(): unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, block_size, hash_fn, prompt_logprobs=5) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1066,7 +1066,7 @@ def test_prefill_plp(): # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3],) assert num_computed_tokens == 3 * 16 @@ -1101,7 +1101,7 @@ def test_prefill_plp(): req2 = make_request( "2", common_token_ids + unique_token_ids, block_size, hash_fn, prompt_logprobs=5 ) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req2) assert len(req2.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1138,7 +1138,7 @@ def test_decode(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 req0 = make_request("0", common_token_ids + unique_token_ids, block_size, sha256) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1196,7 +1196,7 @@ def test_evict(): last_token_id = 5 * 16 + 7 req0 = make_request("0", list(range(last_token_id)), block_size, sha256) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1209,7 +1209,7 @@ def test_evict(): req1 = make_request( "1", list(range(last_token_id, last_token_id + 3 * 16)), block_size, sha256 ) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1230,7 +1230,7 @@ def test_evict(): # Touch the first 2 blocks. req2 = make_request("2", list(range(2 * 16 + 3)), block_size, sha256) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req2) assert computed_blocks.get_block_ids() == ([1, 2],) assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots( @@ -1256,7 +1256,7 @@ def test_hash_block_correct_reuse(): # Allocate 1 block and cache it. num_tokens = block_size * 1 req = make_request("0", list(range(num_tokens)), block_size, sha256) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1270,7 +1270,7 @@ def test_hash_block_correct_reuse(): # Allocate a new block that's not full, make sure hash info on the # block is cleared. req = make_request("1", list(range(num_tokens - 1)), block_size, sha256) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1297,7 +1297,7 @@ def test_computed_blocks_not_evicted(): # Allocate a block and cache it. num_tokens = block_size * 1 req0 = make_request("0", list(range(num_tokens)), block_size, sha256) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1310,7 +1310,7 @@ def test_computed_blocks_not_evicted(): req1 = make_request( "1", list(range(num_tokens, num_tokens * 2)), block_size, sha256 ) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1326,7 +1326,7 @@ def test_computed_blocks_not_evicted(): # Now if we have a cache hit on the first block, we should evict the second # cached block rather than the first one. req2 = make_request("2", list(range(num_tokens * 2)), block_size, sha256) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 1 assert computed_blocks.blocks[0][0].block_id == 1 assert num_computed_tokens == block_size @@ -1357,7 +1357,7 @@ def test_basic_prefix_caching_disabled(): "1", list(range(10)), block_size, sha256 ) # 2 blocks and some more - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1370,7 +1370,7 @@ def test_basic_prefix_caching_disabled(): # No caching. req2 = make_request("2", list(range(16)), block_size, sha256) # shared prefix - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1380,7 +1380,7 @@ def test_basic_prefix_caching_disabled(): # New requests should not have any blocks. req3 = make_request("3", list(range(4)), block_size, sha256) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1562,7 +1562,7 @@ def test_mm_prefix_caching(): mm_positions=mm_positions, mm_hashes=mm_hashes, ) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) # Completed block should have hashes assert not computed_blocks.blocks[0] @@ -1627,7 +1627,7 @@ def test_mm_prefix_caching(): mm_positions=mm_positions, mm_hashes=mm_hashes, ) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) assert len(computed_blocks.blocks[0]) == 3 assert num_computed_tokens == 3 * 16 @@ -1649,7 +1649,7 @@ def test_cache_key_salting(): common_token_ids = [i for i in range(3) for _ in range(block_size)] token_ids = common_token_ids + [3] * 11 req0 = make_request("0", token_ids, block_size, sha256, cache_salt="salt1") - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) # Completed block should have hashes assert not computed_blocks.blocks[0] @@ -1688,7 +1688,7 @@ def test_cache_key_salting(): # Test cache hit with a new request that has the same salt. token_ids = common_token_ids + [4] * 11 req1 = make_request("1", token_ids, block_size, sha256, cache_salt="salt1") - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) # Should match only a prefix of 3 blocks. assert len(computed_blocks.blocks[0]) == 3 assert num_computed_tokens == 3 * block_size @@ -1696,7 +1696,7 @@ def test_cache_key_salting(): # Test cache miss with same content but different salt. token_ids = common_token_ids + [4] * 11 req2 = make_request("2", token_ids, block_size, sha256, cache_salt="salt2") - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 0 assert num_computed_tokens == 0 block_hashes = req2.block_hashes @@ -1730,7 +1730,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | ... | common_token_ids = [i for i in range(3) for _ in range(16)] req0 = make_request("0", common_token_ids, block_size, sha256) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots( @@ -1742,7 +1742,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | req1 = make_request("1", common_token_ids * 2, block_size, sha256) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) assert computed_blocks.blocks[0] == block_part0 assert num_computed_tokens == 3 * 16 manager.allocate_slots( @@ -1760,7 +1760,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| Req2-0 | Req2-1 | ... | req2 = make_request("2", [7] * block_size * 2, block_size, sha256) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots( @@ -1775,7 +1775,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # In this case, the ref_cnt of the computed blocks should not be changed. assert manager.block_pool.free_block_queue.num_free_blocks == 5 req3 = make_request("3", common_token_ids * 3, block_size, sha256) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req3) assert computed_blocks.blocks[0] == block_part1 assert num_computed_tokens == 6 * 16 # Req3 cannot be allocated. @@ -1810,7 +1810,7 @@ def test_reset_prefix_cache(): unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids req1 = make_request("1", all_token_ids, block_size, sha256) - computed_blocks, _ = manager.get_computed_blocks(req1) + computed_blocks, *_ = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 assert len(computed_blocks.blocks[0]) == 3 blocks = manager.allocate_slots( @@ -1845,7 +1845,7 @@ def test_prefix_cache_stats_disabled(): # Call all functions that check whether log_stats is disabled. req = make_request("0", list(range(16)), block_size, sha256) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots( @@ -2253,7 +2253,7 @@ def test_eagle_enabled_removes_last_block(): req = make_request("divisible_request", token_ids, block_size, sha256) # Prime the cache - computed_blocks, _ = manager.get_computed_blocks(req) + computed_blocks, *_ = manager.get_computed_blocks(req) manager.allocate_slots( req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks ) @@ -2261,7 +2261,7 @@ def test_eagle_enabled_removes_last_block(): # New request with same tokens + Eagle enabled req_eagle = make_request("eagle_divisible", token_ids, block_size, sha256) - computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) + computed_blocks, num_tokens, *_ = manager.get_computed_blocks(req_eagle) # Should retain 1 block: # 1. Original 3 blocks → pop last hash → 2 matched blocks @@ -2285,7 +2285,7 @@ def test_eagle_with_partial_blocks(): req = make_request("partial_block_test", token_ids, block_size, sha256) # Prime the cache - computed_blocks, _ = manager.get_computed_blocks(req) + computed_blocks, *_ = manager.get_computed_blocks(req) manager.allocate_slots( req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks ) @@ -2293,7 +2293,7 @@ def test_eagle_with_partial_blocks(): # New request with Eagle enabled req_eagle = make_request("partial_eagle", token_ids, block_size, sha256) - computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) + computed_blocks, num_tokens, *_ = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 assert num_tokens == 1 * block_size @@ -2326,7 +2326,7 @@ def test_eagle_with_sliding_window(): req = make_request("partial_block_test", token_ids, block_size, sha256) # Prime the cache - computed_blocks, _ = manager.get_computed_blocks(req) + computed_blocks, *_ = manager.get_computed_blocks(req) manager.allocate_slots( req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks ) @@ -2337,7 +2337,7 @@ def test_eagle_with_sliding_window(): # New request with Eagle enabled req_eagle = make_request("partial_eagle", token_ids, block_size, sha256) - computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) + computed_blocks, num_tokens, *_ = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 assert num_tokens == 1 * block_size @@ -2357,7 +2357,7 @@ def test_eagle_with_sliding_window(): req_after_evict = make_request( "partial_eagle_after_evict", token_ids, block_size, sha256 ) - computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict) + computed_blocks, num_tokens, *_ = manager.get_computed_blocks(req_after_evict) # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # not considered. But after dropping the last matched block due to eagle, # there will be no matched prefix. @@ -2405,7 +2405,7 @@ def test_different_block_size(): common_token_ids = [i for i in range(10) for _ in range(block_size)] req0 = make_request("0", common_token_ids, block_size, sha256) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[1] assert num_computed_tokens == 0 @@ -2414,13 +2414,13 @@ def test_different_block_size(): ) assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11]) req1 = make_request("1", common_token_ids[: 7 * block_size + 1], block_size, sha256) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) assert len(computed_blocks.blocks[0]) == 3 assert len(computed_blocks.blocks[1]) == 6 assert num_computed_tokens == 6 * 16 req2 = make_request("2", common_token_ids[: 6 * block_size + 1], block_size, sha256) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 3 assert len(computed_blocks.blocks[1]) == 6 assert num_computed_tokens == 6 * 16 @@ -2434,7 +2434,7 @@ def test_different_block_size(): manager.block_pool.cached_block_hash_to_block.pop( make_block_hash_with_group_id(req1.block_hashes[5], 1), 10 ) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) assert len(computed_blocks.blocks[0]) == 2 assert len(computed_blocks.blocks[1]) == 4 assert num_computed_tokens == 4 * 16 From f913446981b5464f4de3ace0049511860a248c58 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Fri, 24 Apr 2026 09:42:12 -0400 Subject: [PATCH 09/14] Test for Cache Coordinator and Scheduler Signed-off-by: Stanislaw Wozniak --- tests/v1/core/test_prefix_caching.py | 71 ++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index f0cdef4faf49..d735348b9496 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -4,6 +4,7 @@ import copy from collections.abc import Callable +from types import SimpleNamespace import pytest import torch @@ -31,6 +32,7 @@ init_none_hash, make_block_hash_with_group_id, ) +from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import ( FullAttentionSpec, KVCacheConfig, @@ -1008,6 +1010,75 @@ def test_prefill_hybrid_model_mamba_align(): manager.free(req0) +def test_hybrid_cache_mamba_align_shared_prefix_detection(): + """Test shared prefix detection heuristic for mamba align cache mode + + HybridKVCacheCoordinator returns num_uncached_common > 0 when a shared + uncached prefix is detected. The heuristic is useful for mamba align cache, + where num_uncached_common is a hint consumed by the scheduler. + + Scheduler consumes the hint in _mamba_block_aligned_split to enforce + scheduling aligned with the common prefix. + """ + block_size = 16 + manager = KVCacheManager( + _make_hybrid_kv_cache_config(block_size, 30, ["full", "mamba_align"]), + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + ) + hash_fn = sha256 + prefix = [i for i in range(3) for _ in range(block_size)] + + # Request: 3 blocks -> block-aligned to ensure caching of the last block: + # - mamba_align caches the last state only + # - the cache manager assigns block hashes to full blocks only + # - this means that we need to schedule a block-aligned request to have state cached + req_0 = make_request("0", prefix, block_size, hash_fn) + computed_blocks, num_computed, num_uncached_common, *_ = ( + manager.get_computed_blocks(req_0) + ) + assert num_computed == 0 # nothing cached yet + assert num_uncached_common == 0 + manager.allocate_slots(req_0, 3 * block_size, 0, computed_blocks) + + # Request: 3 blocks (shared with above) + 7 different tokens + req_1 = make_request("1", prefix + [100] * 7, block_size, hash_fn) + computed_blocks, num_computed, num_uncached_common, *_ = ( + manager.get_computed_blocks(req_1) + ) + assert num_computed == 3 * block_size # we should observe a 3-block cache hit + assert num_uncached_common == 0 + manager.allocate_slots(req_1, 7, 3 * block_size, computed_blocks) + + # Request: 3 blocks, but only 2 blocks shared (replace the last token in 3rd block): + req_2 = make_request("2", prefix[:-1] + [101], block_size, hash_fn) + computed_blocks, num_computed, num_uncached_common, *_ = ( + manager.get_computed_blocks(req_2) + ) + assert num_computed == 0 # mamba_align doesn't cache intermediate blocks + assert num_uncached_common == 2 * block_size # heuristic detects a shared prefix + + # Validate scheduler logic for num_uncached_common_prefix_tokens > 0 + # Create minimal mock with just the needed attributes + mock = SimpleNamespace( + cache_config=SimpleNamespace(block_size=block_size), use_eagle=False + ) + num_new_tokens_adjusted = Scheduler._mamba_block_aligned_split( + self=mock, + request=req_2, + num_new_tokens=3 * block_size, + num_uncached_common_prefix_tokens=num_uncached_common, + ) + assert num_new_tokens_adjusted == 2 * block_size # adjust to the common prefix + + manager.allocate_slots(req_2, 3 * block_size, 0, computed_blocks) + # Cleanup + manager.free(req_0) + manager.free(req_1) + manager.free(req_2) + + def test_prefill_plp(): """Test prefill with APC and some prompt logprobs (plp) requests. From ea7bbad9ba631dffe4baae6fa49860ae7472d2de Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Mon, 8 Jun 2026 03:31:31 -0400 Subject: [PATCH 10/14] Changed to attribute per tdoublep's suggestion Signed-off-by: Stanislaw Wozniak --- tests/v1/core/test_prefix_caching.py | 117 +++++++++++++-------------- vllm/v1/core/kv_cache_coordinator.py | 27 +++---- vllm/v1/core/kv_cache_manager.py | 12 +-- vllm/v1/core/sched/scheduler.py | 25 ++++-- vllm/v1/simple_kv_offload/manager.py | 2 +- 5 files changed, 94 insertions(+), 89 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index a7a3991ae944..f6cd0e3e6fc0 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -238,7 +238,7 @@ def test_prefill(hash_fn): unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -268,7 +268,7 @@ def test_prefill(hash_fn): # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3],) assert num_computed_tokens == 3 * 16 @@ -302,7 +302,7 @@ def test_prefill(hash_fn): # Incomplete 1 block (6 tokens) unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(req2.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3],) assert num_computed_tokens == 3 * 16 @@ -322,7 +322,7 @@ def test_prefill(hash_fn): # Cache miss and eviction. req3 = make_request("3", [99] * (16 * 10), block_size, hash_fn) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req3) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -364,7 +364,7 @@ def test_prefill_hybrid_model(): unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -400,7 +400,7 @@ def test_prefill_hybrid_model(): unique_token_ids = [3] * 5 all_token_ids = common_token_ids + unique_token_ids req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6, 7], [0, 10, 11]) assert num_computed_tokens == 3 * 16 @@ -534,7 +534,7 @@ def test_prefill_hybrid_model_eagle(): unique_token_ids = [6] * 7 all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(req0.block_hashes) == len(all_token_ids) // block_size assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -571,7 +571,7 @@ def test_prefill_hybrid_model_eagle(): unique_token_ids = [6] * 5 all_token_ids = common_token_ids + unique_token_ids req1 = make_request("1", all_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == num_full_blocks assert computed_blocks.get_block_ids() == ( [1, 2, 3, 4, 5], @@ -707,7 +707,7 @@ def _test_partial_request_hit( req = make_request(request_id, prompt_token_ids, block_size, sha256) for hash_with_group_id in hash_to_evict: manager.block_pool.cached_block_hash_to_block._cache.pop(hash_with_group_id) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert len(req.block_hashes) == num_full_blocks assert num_computed_tokens == expect_hit_length * block_size for block_per_group in computed_blocks.blocks: @@ -868,7 +868,7 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]): # First request: no cache hit initially req0 = make_request("0", all_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] # No cache hit initially @@ -885,7 +885,7 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]): # Second request: should hit cached blocks for common prefix req1 = make_request("1", common_token_ids + [4] * 5, block_size, hash_fn) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) # Should hit cached blocks for all groups assert num_computed_tokens == 3 * block_size @@ -945,7 +945,7 @@ def test_prefill_hybrid_model_combinations_eagle( # First request: no cache hit initially req0 = make_request("0", all_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(req0.block_hashes) == num_full_blocks assert not computed_blocks.blocks[0] # No cache hit initially @@ -961,7 +961,7 @@ def test_prefill_hybrid_model_combinations_eagle( # Second request: should hit cached blocks for common prefix all_token_ids = common_token_ids + [6] * 5 req1 = make_request("1", all_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) # Should hit cached blocks for all groups assert num_computed_tokens == expect_hit_length * block_size @@ -1014,7 +1014,7 @@ def test_prefill_hybrid_model_mamba_align(): # First request: allocate_slots should not crash with the assertion error # in MambaManager.cache_blocks() when null blocks are present. req0 = make_request("0", all_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, num_computed_tokens, computed_blocks) @@ -1049,27 +1049,24 @@ def test_hybrid_cache_mamba_align_shared_prefix_detection(): # - the cache manager assigns block hashes to full blocks only # - this means that we need to schedule a block-aligned request to have state cached req_0 = make_request("0", prefix, block_size, hash_fn) - computed_blocks, num_computed, num_uncached_common, *_ = ( - manager.get_computed_blocks(req_0) - ) + computed_blocks, num_computed = manager.get_computed_blocks(req_0) + num_uncached_common = manager.coordinator.num_uncached_common_prefix_tokens assert num_computed == 0 # nothing cached yet assert num_uncached_common == 0 manager.allocate_slots(req_0, 3 * block_size, 0, computed_blocks) # Request: 3 blocks (shared with above) + 7 different tokens req_1 = make_request("1", prefix + [100] * 7, block_size, hash_fn) - computed_blocks, num_computed, num_uncached_common, *_ = ( - manager.get_computed_blocks(req_1) - ) + computed_blocks, num_computed = manager.get_computed_blocks(req_1) + num_uncached_common = manager.coordinator.num_uncached_common_prefix_tokens assert num_computed == 3 * block_size # we should observe a 3-block cache hit assert num_uncached_common == 0 manager.allocate_slots(req_1, 7, 3 * block_size, computed_blocks) # Request: 3 blocks, but only 2 blocks shared (replace the last token in 3rd block): req_2 = make_request("2", prefix[:-1] + [101], block_size, hash_fn) - computed_blocks, num_computed, num_uncached_common, *_ = ( - manager.get_computed_blocks(req_2) - ) + computed_blocks, num_computed = manager.get_computed_blocks(req_2) + num_uncached_common = manager.coordinator.num_uncached_common_prefix_tokens assert num_computed == 0 # mamba_align doesn't cache intermediate blocks assert num_uncached_common == 2 * block_size # heuristic detects a shared prefix @@ -1119,7 +1116,7 @@ def test_prefill_plp(): unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, block_size, hash_fn, prompt_logprobs=5) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1151,7 +1148,7 @@ def test_prefill_plp(): # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3],) assert num_computed_tokens == 3 * 16 @@ -1186,7 +1183,7 @@ def test_prefill_plp(): req2 = make_request( "2", common_token_ids + unique_token_ids, block_size, hash_fn, prompt_logprobs=5 ) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(req2.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1223,7 +1220,7 @@ def test_decode(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 req0 = make_request("0", common_token_ids + unique_token_ids, block_size, sha256) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1281,7 +1278,7 @@ def test_evict(): last_token_id = 5 * 16 + 7 req0 = make_request("0", list(range(last_token_id)), block_size, sha256) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1294,7 +1291,7 @@ def test_evict(): req1 = make_request( "1", list(range(last_token_id, last_token_id + 3 * 16)), block_size, sha256 ) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1315,7 +1312,7 @@ def test_evict(): # Touch the first 2 blocks. req2 = make_request("2", list(range(2 * 16 + 3)), block_size, sha256) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert computed_blocks.get_block_ids() == ([1, 2],) assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots( @@ -1341,7 +1338,7 @@ def test_hash_block_correct_reuse(): # Allocate 1 block and cache it. num_tokens = block_size * 1 req = make_request("0", list(range(num_tokens)), block_size, sha256) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1355,7 +1352,7 @@ def test_hash_block_correct_reuse(): # Allocate a new block that's not full, make sure hash info on the # block is cleared. req = make_request("1", list(range(num_tokens - 1)), block_size, sha256) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1382,7 +1379,7 @@ def test_computed_blocks_not_evicted(): # Allocate a block and cache it. num_tokens = block_size * 1 req0 = make_request("0", list(range(num_tokens)), block_size, sha256) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1395,7 +1392,7 @@ def test_computed_blocks_not_evicted(): req1 = make_request( "1", list(range(num_tokens, num_tokens * 2)), block_size, sha256 ) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1411,7 +1408,7 @@ def test_computed_blocks_not_evicted(): # Now if we have a cache hit on the first block, we should evict the second # cached block rather than the first one. req2 = make_request("2", list(range(num_tokens * 2)), block_size, sha256) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 1 assert computed_blocks.blocks[0][0].block_id == 1 assert num_computed_tokens == block_size @@ -1442,7 +1439,7 @@ def test_basic_prefix_caching_disabled(): "1", list(range(10)), block_size, sha256 ) # 2 blocks and some more - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1455,7 +1452,7 @@ def test_basic_prefix_caching_disabled(): # No caching. req2 = make_request("2", list(range(16)), block_size, sha256) # shared prefix - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1465,7 +1462,7 @@ def test_basic_prefix_caching_disabled(): # New requests should not have any blocks. req3 = make_request("3", list(range(4)), block_size, sha256) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req3) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots( @@ -1647,7 +1644,7 @@ def test_mm_prefix_caching(): mm_positions=mm_positions, mm_hashes=mm_hashes, ) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes assert not computed_blocks.blocks[0] @@ -1712,7 +1709,7 @@ def test_mm_prefix_caching(): mm_positions=mm_positions, mm_hashes=mm_hashes, ) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(computed_blocks.blocks[0]) == 3 assert num_computed_tokens == 3 * 16 @@ -1734,7 +1731,7 @@ def test_cache_key_salting(): common_token_ids = [i for i in range(3) for _ in range(block_size)] token_ids = common_token_ids + [3] * 11 req0 = make_request("0", token_ids, block_size, sha256, cache_salt="salt1") - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes assert not computed_blocks.blocks[0] @@ -1773,7 +1770,7 @@ def test_cache_key_salting(): # Test cache hit with a new request that has the same salt. token_ids = common_token_ids + [4] * 11 req1 = make_request("1", token_ids, block_size, sha256, cache_salt="salt1") - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) # Should match only a prefix of 3 blocks. assert len(computed_blocks.blocks[0]) == 3 assert num_computed_tokens == 3 * block_size @@ -1781,7 +1778,7 @@ def test_cache_key_salting(): # Test cache miss with same content but different salt. token_ids = common_token_ids + [4] * 11 req2 = make_request("2", token_ids, block_size, sha256, cache_salt="salt2") - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 0 assert num_computed_tokens == 0 block_hashes = req2.block_hashes @@ -1815,7 +1812,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | ... | common_token_ids = [i for i in range(3) for _ in range(16)] req0 = make_request("0", common_token_ids, block_size, sha256) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots( @@ -1827,7 +1824,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | req1 = make_request("1", common_token_ids * 2, block_size, sha256) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert computed_blocks.blocks[0] == block_part0 assert num_computed_tokens == 3 * 16 manager.allocate_slots( @@ -1845,7 +1842,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| Req2-0 | Req2-1 | ... | req2 = make_request("2", [7] * block_size * 2, block_size, sha256) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots( @@ -1860,7 +1857,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # In this case, the ref_cnt of the computed blocks should not be changed. assert manager.block_pool.free_block_queue.num_free_blocks == 5 req3 = make_request("3", common_token_ids * 3, block_size, sha256) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req3) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert computed_blocks.blocks[0] == block_part1 assert num_computed_tokens == 6 * 16 # Req3 cannot be allocated. @@ -1895,7 +1892,7 @@ def test_reset_prefix_cache(): unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids req1 = make_request("1", all_token_ids, block_size, sha256) - computed_blocks, *_ = manager.get_computed_blocks(req1) + computed_blocks = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 assert len(computed_blocks.blocks[0]) == 3 blocks = manager.allocate_slots( @@ -1930,7 +1927,7 @@ def test_prefix_cache_stats_disabled(): # Call all functions that check whether log_stats is disabled. req = make_request("0", list(range(16)), block_size, sha256) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots( @@ -2425,7 +2422,7 @@ def test_eagle_enabled_removes_last_block(): req = make_request("divisible_request", token_ids, block_size, sha256) # Prime the cache - computed_blocks, *_ = manager.get_computed_blocks(req) + computed_blocks = manager.get_computed_blocks(req) manager.allocate_slots( req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks ) @@ -2433,7 +2430,7 @@ def test_eagle_enabled_removes_last_block(): # New request with same tokens + Eagle enabled req_eagle = make_request("eagle_divisible", token_ids, block_size, sha256) - computed_blocks, num_tokens, *_ = manager.get_computed_blocks(req_eagle) + computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Should retain 1 block: # 1. Original 3 blocks → pop last hash → 2 matched blocks @@ -2457,7 +2454,7 @@ def test_eagle_with_partial_blocks(): req = make_request("partial_block_test", token_ids, block_size, sha256) # Prime the cache - computed_blocks, *_ = manager.get_computed_blocks(req) + computed_blocks = manager.get_computed_blocks(req) manager.allocate_slots( req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks ) @@ -2465,7 +2462,7 @@ def test_eagle_with_partial_blocks(): # New request with Eagle enabled req_eagle = make_request("partial_eagle", token_ids, block_size, sha256) - computed_blocks, num_tokens, *_ = manager.get_computed_blocks(req_eagle) + computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 assert num_tokens == 1 * block_size @@ -2498,7 +2495,7 @@ def test_eagle_with_sliding_window(): req = make_request("partial_block_test", token_ids, block_size, sha256) # Prime the cache - computed_blocks, *_ = manager.get_computed_blocks(req) + computed_blocks = manager.get_computed_blocks(req) manager.allocate_slots( req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks ) @@ -2509,7 +2506,7 @@ def test_eagle_with_sliding_window(): # New request with Eagle enabled req_eagle = make_request("partial_eagle", token_ids, block_size, sha256) - computed_blocks, num_tokens, *_ = manager.get_computed_blocks(req_eagle) + computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 assert num_tokens == 1 * block_size @@ -2529,7 +2526,7 @@ def test_eagle_with_sliding_window(): req_after_evict = make_request( "partial_eagle_after_evict", token_ids, block_size, sha256 ) - computed_blocks, num_tokens, *_ = manager.get_computed_blocks(req_after_evict) + computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict) # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # not considered. But after dropping the last matched block due to eagle, # there will be no matched prefix. @@ -2772,7 +2769,7 @@ def test_different_block_size(): common_token_ids = [i for i in range(10) for _ in range(block_size)] req0 = make_request("0", common_token_ids, block_size, sha256) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req0) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[1] assert num_computed_tokens == 0 @@ -2781,13 +2778,13 @@ def test_different_block_size(): ) assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11]) req1 = make_request("1", common_token_ids[: 7 * block_size + 1], block_size, sha256) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(computed_blocks.blocks[0]) == 3 assert len(computed_blocks.blocks[1]) == 6 assert num_computed_tokens == 6 * 16 req2 = make_request("2", common_token_ids[: 6 * block_size + 1], block_size, sha256) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 3 assert len(computed_blocks.blocks[1]) == 6 assert num_computed_tokens == 6 * 16 @@ -2801,7 +2798,7 @@ def test_different_block_size(): manager.block_pool.cached_block_hash_to_block.pop( make_block_hash_with_group_id(req1.block_hashes[5], 1), 10 ) - computed_blocks, num_computed_tokens, *_ = manager.get_computed_blocks(req1) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(computed_blocks.blocks[0]) == 2 assert len(computed_blocks.blocks[1]) == 4 assert num_computed_tokens == 4 * 16 diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 447fcf4f60d2..192095d84a10 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -273,7 +273,7 @@ def find_longest_cache_hit( self, block_hashes: list[BlockHash], max_cache_hit_length: int, - ) -> tuple[tuple[list[KVCacheBlock], ...], int, int]: + ) -> tuple[tuple[list[KVCacheBlock], ...], int]: pass def new_step_starts(self) -> None: @@ -325,11 +325,11 @@ def find_longest_cache_hit( self, block_hashes: list[BlockHash], max_cache_hit_length: int, - ) -> tuple[tuple[list[KVCacheBlock], ...], int, int]: + ) -> tuple[tuple[list[KVCacheBlock], ...], int]: blocks: tuple[list[KVCacheBlock], ...] = tuple( [] for _ in range(self.num_single_type_manager) ) - return blocks, 0, 0 + return blocks, 0 class UnitaryKVCacheCoordinator(KVCacheCoordinator): @@ -389,7 +389,7 @@ def find_longest_cache_hit( self, block_hashes: list[BlockHash], max_cache_hit_length: int, - ) -> tuple[tuple[list[KVCacheBlock], ...], int, int]: + ) -> tuple[tuple[list[KVCacheBlock], ...], int]: hit_blocks = self.single_type_managers[0].find_longest_cache_hit( block_hashes=block_hashes, max_length=max_cache_hit_length, @@ -401,7 +401,7 @@ def find_longest_cache_hit( dcp_world_size=self.dcp_world_size, pcp_world_size=self.pcp_world_size, ) - return hit_blocks, len(hit_blocks[0]) * self.block_size, 0 + return hit_blocks, len(hit_blocks[0]) * self.block_size class SpecGroup(NamedTuple): @@ -533,7 +533,7 @@ def find_longest_cache_hit( self, block_hashes: list[BlockHash], max_cache_hit_length: int, - ) -> tuple[tuple[list[KVCacheBlock], ...], int, int]: + ) -> tuple[tuple[list[KVCacheBlock], ...], int]: """ Find the longest cache hit using an iterative fixed-point algorithm. @@ -640,15 +640,12 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: # Uncached shared prefix detection heuristic: # If any attention group cached a longer prefix than the current common # prefix, there was a request with a shared prefix of at least that - # length in the past. Return the length of such common prefix: - num_uncached_common_prefix_tokens = longest_hit_length - hit_length - return ( - tuple( - blocks if blocks is not None else [] for blocks in hit_blocks_by_group - ), - hit_length, - num_uncached_common_prefix_tokens, - ) + # length in the past. Return the length of such common prefix. + # Implementation: Use an attribute to avoid function return signature change. + self.num_uncached_common_prefix_tokens = longest_hit_length - hit_length + return tuple( + blocks if blocks is not None else [] for blocks in hit_blocks_by_group + ), hit_length def get_kv_cache_coordinator( diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 5ee51a7e8a68..659130d376d4 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -193,7 +193,7 @@ def make_prefix_cache_stats(self) -> PrefixCacheStats | None: self.prefix_cache_stats = PrefixCacheStats() return stats - def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int, int]: + def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -213,7 +213,7 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int, int # (which happens when the request requires prompt logprobs # or calls a pooling model with all pooling). if not self.enable_caching or request.skip_reading_prefix_cache: - return self.empty_kv_cache_blocks, 0, 0 + return self.empty_kv_cache_blocks, 0 # NOTE: When all tokens hit the cache, we must recompute the last token # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1. @@ -222,7 +222,7 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int, int # num_computed_tokens to be block-size aligned. Removing this limitation # could slightly improve performance in the future. max_cache_hit_length = request.num_tokens - 1 - computed_blocks, num_new_computed_tokens, num_uncached_common_prefix_tokens = ( + computed_blocks, num_new_computed_tokens = ( self.coordinator.find_longest_cache_hit( request.block_hashes, max_cache_hit_length ) @@ -236,11 +236,7 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int, int preempted=request.num_preemptions > 0, ) - return ( - self.create_kv_cache_blocks(computed_blocks), - num_new_computed_tokens, - num_uncached_common_prefix_tokens, - ) + return self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens def allocate_slots( self, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 9b2a0d6125d6..5f5c2ca79045 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -621,11 +621,26 @@ def schedule(self) -> SchedulerOutput: # Get already-cached tokens. if request.num_computed_tokens == 0: # Get locally-cached tokens. - ( - new_computed_blocks, - num_new_local_computed_tokens, - num_uncached_common_prefix_tokens, - ) = self.kv_cache_manager.get_computed_blocks(request) + new_computed_blocks, num_new_local_computed_tokens = ( + self.kv_cache_manager.get_computed_blocks(request) + ) + + # In case of hybrid models, obtain hint for Marconi-style APC logic + # More proper check for hybrid model case would be: + # if isinstance(self.kv_cache_manager.coordinator, + # HybridKVCacheCoordinator): + # but the check below is similar and avoids + # importing HybridKVCacheCoordinator: + if self.has_mamba_layers: + # obtain num_uncached_common_prefix_tokens from coordinator's + # attribute if exists, else 0 + # (alternative solution: pass as a return value, but + # this would require multiple function signature changes) + num_uncached_common_prefix_tokens = getattr( + self.kv_cache_manager.coordinator, + "num_uncached_common_prefix_tokens", + 0, + ) # Get externally-cached tokens if using a KVConnector. if self.connector is not None: diff --git a/vllm/v1/simple_kv_offload/manager.py b/vllm/v1/simple_kv_offload/manager.py index 869b6588e9eb..f61c4320dffd 100644 --- a/vllm/v1/simple_kv_offload/manager.py +++ b/vllm/v1/simple_kv_offload/manager.py @@ -247,7 +247,7 @@ def get_num_new_matched_tokens( max_hit_len = request.num_tokens - 1 - num_computed_tokens if max_hit_len <= 0: return 0, False - cpu_hit_blocks, hit_length, _ = self.cpu_coordinator.find_longest_cache_hit( + cpu_hit_blocks, hit_length = self.cpu_coordinator.find_longest_cache_hit( remaining_hashes, max_hit_len ) From 73830b56b45691110f5f18ce21bf8d6ac090e1c7 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Mon, 8 Jun 2026 03:48:24 -0400 Subject: [PATCH 11/14] Small fixes. Signed-off-by: Stanislaw Wozniak --- tests/v1/core/test_prefix_caching.py | 10 +++++----- vllm/v1/core/kv_cache_manager.py | 3 --- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index f6cd0e3e6fc0..ac1eed59b001 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1035,7 +1035,7 @@ def test_hybrid_cache_mamba_align_shared_prefix_detection(): scheduling aligned with the common prefix. """ block_size = 16 - manager = KVCacheManager( + manager = make_kv_cache_manager( _make_hybrid_kv_cache_config(block_size, 30, ["full", "mamba_align"]), max_model_len=8192, enable_caching=True, @@ -1892,7 +1892,7 @@ def test_reset_prefix_cache(): unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids req1 = make_request("1", all_token_ids, block_size, sha256) - computed_blocks = manager.get_computed_blocks(req1) + computed_blocks, _ = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 assert len(computed_blocks.blocks[0]) == 3 blocks = manager.allocate_slots( @@ -2422,7 +2422,7 @@ def test_eagle_enabled_removes_last_block(): req = make_request("divisible_request", token_ids, block_size, sha256) # Prime the cache - computed_blocks = manager.get_computed_blocks(req) + computed_blocks, _ = manager.get_computed_blocks(req) manager.allocate_slots( req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks ) @@ -2454,7 +2454,7 @@ def test_eagle_with_partial_blocks(): req = make_request("partial_block_test", token_ids, block_size, sha256) # Prime the cache - computed_blocks = manager.get_computed_blocks(req) + computed_blocks, _ = manager.get_computed_blocks(req) manager.allocate_slots( req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks ) @@ -2495,7 +2495,7 @@ def test_eagle_with_sliding_window(): req = make_request("partial_block_test", token_ids, block_size, sha256) # Prime the cache - computed_blocks = manager.get_computed_blocks(req) + computed_blocks, _ = manager.get_computed_blocks(req) manager.allocate_slots( req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 659130d376d4..d98520da95fc 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -204,9 +204,6 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: A tuple containing: - A list of blocks that are computed for the request. - The number of computed tokens. - - 0, or the number of uncached shared prefix tokens - beyond computed tokens, as detected by shared prefix detection - (currently implemented in HybridKVCacheCoordinator) """ # We skip finding the prefix cache hit when prefix caching is # disabled or the request is marked as skipping kv cache read From 376ade3915891312984c8bfae2dcb907d775ed32 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Tue, 9 Jun 2026 03:32:33 -0400 Subject: [PATCH 12/14] Less verbose comments Signed-off-by: Stanislaw Wozniak --- tests/v1/core/test_prefix_caching.py | 16 +++++---------- vllm/v1/core/kv_cache_coordinator.py | 9 ++------- vllm/v1/core/sched/scheduler.py | 29 +++------------------------- 3 files changed, 10 insertions(+), 44 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 8d9d73100c1f..366cd518557c 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1029,11 +1029,8 @@ def test_hybrid_cache_mamba_align_shared_prefix_detection(): """Test shared prefix detection heuristic for mamba align cache mode HybridKVCacheCoordinator returns num_uncached_common > 0 when a shared - uncached prefix is detected. The heuristic is useful for mamba align cache, - where num_uncached_common is a hint consumed by the scheduler. - - Scheduler consumes the hint in _mamba_block_aligned_split to enforce - scheduling aligned with the common prefix. + uncached prefix is detected. With mamba_align cache, _mamba_block_aligned_split + enforces scheduling aligned with the common prefix. """ block_size = 16 manager = make_kv_cache_manager( @@ -1043,12 +1040,9 @@ def test_hybrid_cache_mamba_align_shared_prefix_detection(): hash_block_size=block_size, ) hash_fn = sha256 - prefix = [i for i in range(3) for _ in range(block_size)] - # Request: 3 blocks -> block-aligned to ensure caching of the last block: - # - mamba_align caches the last state only - # - the cache manager assigns block hashes to full blocks only - # - this means that we need to schedule a block-aligned request to have state cached + # Request: 3 blocks + prefix = [i for i in range(3) for _ in range(block_size)] req_0 = make_request("0", prefix, block_size, hash_fn) computed_blocks, num_computed = manager.get_computed_blocks(req_0) num_uncached_common = manager.coordinator.num_uncached_common_prefix_tokens @@ -1071,7 +1065,7 @@ def test_hybrid_cache_mamba_align_shared_prefix_detection(): assert num_computed == 0 # mamba_align doesn't cache intermediate blocks assert num_uncached_common == 2 * block_size # heuristic detects a shared prefix - # Validate scheduler logic for num_uncached_common_prefix_tokens > 0 + # Next, validate scheduler logic for num_uncached_common_prefix_tokens > 0 # Create minimal mock with just the needed attributes mock = SimpleNamespace( cache_config=SimpleNamespace(block_size=block_size), use_eagle=False diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index bd484b5745df..2461f035dbd5 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -669,8 +669,6 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: hit_blocks_by_group[group_id] = blocks # Collect information on the longest cached prefix overall - # (no matter the attention type) to allow for more complex - # caching policies longest_hit_length = max(longest_hit_length, curr_hit_length) if curr_hit_length >= hit_length: @@ -687,11 +685,8 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: if (blks := hit_blocks_by_group[group_id]) is not None: del blks[num_blocks:] - # Uncached shared prefix detection heuristic: - # If any attention group cached a longer prefix than the current common - # prefix, there was a request with a shared prefix of at least that - # length in the past. Return the length of such common prefix. - # Implementation: Use an attribute to avoid function return signature change. + # Uncached shared prefix detection: If any attn. group cached a longer prefix + # than the current prefix, it is an uncached common prefix across requests: self.num_uncached_common_prefix_tokens = longest_hit_length - hit_length return tuple( blocks if blocks is not None else [] for blocks in hit_blocks_by_group diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4d3846cf6dfd..478257d83f8d 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -338,28 +338,14 @@ def _mamba_block_aligned_split( pass # Marconi cache admission optimization: - # Create cache entries at divergence points of common prefixes. - # - # Implementation: - # If uncached common prefix (num_uncached_common_prefix_tokens) - # is long enough to justify its caching ( >= block_size) - # AND - # currently scheduled token count is longer than the common prefix + # cache common prefixes by scheduling num_new_tokens = common prefix length if ( num_uncached_common_prefix_tokens >= block_size and num_new_tokens > num_uncached_common_prefix_tokens ): - # Then force to cache at the end of the common prefix - # by limiting the num_new_tokens to the length of that prefix: num_new_tokens = num_uncached_common_prefix_tokens - # This should be still block aligned as: - # - token hit counts are block aligned - # - thus num_uncached_common_prefix_tokens is block aligned - # - attention and mamba block sizes are equal - # Optionally, we can verify this: - assert num_new_tokens % block_size == 0 - # Or force block re-alignment: - # num_new_tokens = num_new_tokens // block_size * block_size + # keep alignment to block_size + num_new_tokens = num_new_tokens // block_size * block_size return num_new_tokens def schedule(self) -> SchedulerOutput: @@ -638,16 +624,7 @@ def schedule(self) -> SchedulerOutput: ) # In case of hybrid models, obtain hint for Marconi-style APC logic - # More proper check for hybrid model case would be: - # if isinstance(self.kv_cache_manager.coordinator, - # HybridKVCacheCoordinator): - # but the check below is similar and avoids - # importing HybridKVCacheCoordinator: if self.has_mamba_layers: - # obtain num_uncached_common_prefix_tokens from coordinator's - # attribute if exists, else 0 - # (alternative solution: pass as a return value, but - # this would require multiple function signature changes) num_uncached_common_prefix_tokens = getattr( self.kv_cache_manager.coordinator, "num_uncached_common_prefix_tokens", From feb75632041951bad2bcfaa7252080a965f48208 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Tue, 9 Jun 2026 03:37:12 -0400 Subject: [PATCH 13/14] Less verbose comments Signed-off-by: Stanislaw Wozniak --- vllm/v1/core/kv_cache_coordinator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 2461f035dbd5..56150142bf87 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -668,7 +668,6 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: for group_id, blocks in zip(group_ids, hit_blocks): hit_blocks_by_group[group_id] = blocks - # Collect information on the longest cached prefix overall longest_hit_length = max(longest_hit_length, curr_hit_length) if curr_hit_length >= hit_length: From 99640b35af9df28319fdca223156531d9fcb6772 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Tue, 9 Jun 2026 03:52:57 -0400 Subject: [PATCH 14/14] Initialize value Signed-off-by: Stanislaw Wozniak --- vllm/v1/core/sched/scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 478257d83f8d..e61b9991b210 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -615,6 +615,7 @@ def schedule(self) -> SchedulerOutput: num_external_computed_tokens = 0 load_kv_async = False connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0 + num_uncached_common_prefix_tokens = 0 # Get already-cached tokens. if request.num_computed_tokens == 0: