diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 8f334ce9ac0e..366cd518557c 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -5,6 +5,7 @@ import copy from collections.abc import Callable from math import lcm +from types import SimpleNamespace import pytest import torch @@ -33,6 +34,7 @@ init_none_hash, make_block_hash_with_group_id, ) +from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import ( FullAttentionSpec, KVCacheConfig, @@ -1023,6 +1025,66 @@ def test_prefill_hybrid_model_mamba_align(): manager.free(req0) +def test_hybrid_cache_mamba_align_shared_prefix_detection(): + """Test shared prefix detection heuristic for mamba align cache mode + + HybridKVCacheCoordinator returns num_uncached_common > 0 when a shared + uncached prefix is detected. With mamba_align cache, _mamba_block_aligned_split + enforces scheduling aligned with the common prefix. + """ + block_size = 16 + manager = make_kv_cache_manager( + _make_hybrid_kv_cache_config(block_size, 30, ["full", "mamba_align"]), + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + ) + hash_fn = sha256 + + # Request: 3 blocks + prefix = [i for i in range(3) for _ in range(block_size)] + req_0 = make_request("0", prefix, block_size, hash_fn) + computed_blocks, num_computed = manager.get_computed_blocks(req_0) + num_uncached_common = manager.coordinator.num_uncached_common_prefix_tokens + assert num_computed == 0 # nothing cached yet + assert num_uncached_common == 0 + manager.allocate_slots(req_0, 3 * block_size, 0, computed_blocks) + + # Request: 3 blocks (shared with above) + 7 different tokens + req_1 = make_request("1", prefix + [100] * 7, block_size, hash_fn) + computed_blocks, num_computed = manager.get_computed_blocks(req_1) + num_uncached_common = manager.coordinator.num_uncached_common_prefix_tokens + assert num_computed == 3 * block_size # we should observe a 3-block cache hit + assert num_uncached_common == 0 + manager.allocate_slots(req_1, 7, 3 * block_size, computed_blocks) + + # Request: 3 blocks, but only 2 blocks shared (replace the last token in 3rd block): + req_2 = make_request("2", prefix[:-1] + [101], block_size, hash_fn) + computed_blocks, num_computed = manager.get_computed_blocks(req_2) + num_uncached_common = manager.coordinator.num_uncached_common_prefix_tokens + assert num_computed == 0 # mamba_align doesn't cache intermediate blocks + assert num_uncached_common == 2 * block_size # heuristic detects a shared prefix + + # Next, validate scheduler logic for num_uncached_common_prefix_tokens > 0 + # Create minimal mock with just the needed attributes + mock = SimpleNamespace( + cache_config=SimpleNamespace(block_size=block_size), use_eagle=False + ) + num_new_tokens_adjusted = Scheduler._mamba_block_aligned_split( + self=mock, + request=req_2, + num_new_tokens=3 * block_size, + num_uncached_common_prefix_tokens=num_uncached_common, + ) + assert num_new_tokens_adjusted == 2 * block_size # adjust to the common prefix + + manager.allocate_slots(req_2, 3 * block_size, 0, computed_blocks) + # Cleanup + manager.free(req_0) + manager.free(req_1) + manager.free(req_2) + + def test_hybrid_model_mamba_align_with_dynamic_draft_tokens(): """Regression test for https://github.com/vllm-project/vllm/issues/39271. diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 89b1e84a44e7..56150142bf87 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -611,6 +611,7 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: num_groups = len(self.kv_cache_config.kv_cache_groups) hit_length = max_cache_hit_length + longest_hit_length = 0 hit_blocks_by_group: list[list[KVCacheBlock] | None] = [None] * num_groups # Simple hybrid (1 full attn + 1 other): one iteration suffices. @@ -667,6 +668,8 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: for group_id, blocks in zip(group_ids, hit_blocks): hit_blocks_by_group[group_id] = blocks + longest_hit_length = max(longest_hit_length, curr_hit_length) + if curr_hit_length >= hit_length: break hit_length = curr_hit_length @@ -681,6 +684,9 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: if (blks := hit_blocks_by_group[group_id]) is not None: del blks[num_blocks:] + # Uncached shared prefix detection: If any attn. group cached a longer prefix + # than the current prefix, it is an uncached common prefix across requests: + self.num_uncached_common_prefix_tokens = longest_hit_length - hit_length return tuple( blocks if blocks is not None else [] for blocks in hit_blocks_by_group ), hit_length diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 889232c3e4d2..e61b9991b210 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -296,6 +296,7 @@ def _mamba_block_aligned_split( num_new_tokens: int, num_new_local_computed_tokens: int = 0, num_external_computed_tokens: int = 0, + num_uncached_common_prefix_tokens: int = 0, ) -> int: num_computed_tokens = ( request.num_computed_tokens @@ -335,6 +336,16 @@ def _mamba_block_aligned_split( else: # prefill the last few tokens pass + + # Marconi cache admission optimization: + # cache common prefixes by scheduling num_new_tokens = common prefix length + if ( + num_uncached_common_prefix_tokens >= block_size + and num_new_tokens > num_uncached_common_prefix_tokens + ): + num_new_tokens = num_uncached_common_prefix_tokens + # keep alignment to block_size + num_new_tokens = num_new_tokens // block_size * block_size return num_new_tokens def schedule(self) -> SchedulerOutput: @@ -604,6 +615,7 @@ def schedule(self) -> SchedulerOutput: num_external_computed_tokens = 0 load_kv_async = False connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0 + num_uncached_common_prefix_tokens = 0 # Get already-cached tokens. if request.num_computed_tokens == 0: @@ -612,6 +624,14 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_manager.get_computed_blocks(request) ) + # In case of hybrid models, obtain hint for Marconi-style APC logic + if self.has_mamba_layers: + num_uncached_common_prefix_tokens = getattr( + self.kv_cache_manager.coordinator, + "num_uncached_common_prefix_tokens", + 0, + ) + # Get externally-cached tokens if using a KVConnector. if self.connector is not None: ext_tokens, load_kv_async = ( @@ -724,6 +744,7 @@ def schedule(self) -> SchedulerOutput: num_new_tokens, num_new_local_computed_tokens, num_external_computed_tokens, + num_uncached_common_prefix_tokens, ) if num_new_tokens == 0: break