From 6726d3ee28c55c03fa32c3c37c7d3c866fcefc30 Mon Sep 17 00:00:00 2001
From: bofengluo
Date: Thu, 1 Jan 2026 00:25:26 +0000
Subject: [PATCH] feat(kv-cache): support multiple sliding window groups in
 HybridKVCacheCoordinator with tests

Changes to be committed:
	new file:   tests/v1/core/test_hybrid_kv_cache_coordinator.py
	modified:   vllm/v1/core/kv_cache_coordinator.py

Signed-off-by: bofengluo
---
 .../core/test_hybrid_kv_cache_coordinator.py | 202 +++++++++++++++++
 vllm/v1/core/kv_cache_coordinator.py         | 205 +++++++++++++++---
 2 files changed, 374 insertions(+), 33 deletions(-)
 create mode 100644 tests/v1/core/test_hybrid_kv_cache_coordinator.py

diff --git a/tests/v1/core/test_hybrid_kv_cache_coordinator.py b/tests/v1/core/test_hybrid_kv_cache_coordinator.py
new file mode 100644
index 000000000000..de6606fa4a7d
--- /dev/null
+++ b/tests/v1/core/test_hybrid_kv_cache_coordinator.py
@@ -0,0 +1,202 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Test HybridKVCacheCoordinator with multiple sliding window groups."""
+
+import pytest
+import torch
+
+from vllm.sampling_params import SamplingParams
+from vllm.utils.hashing import sha256
+from vllm.v1.core.kv_cache_manager import KVCacheManager
+from vllm.v1.core.kv_cache_utils import (
+    get_request_block_hasher,
+    init_none_hash,
+)
+from vllm.v1.kv_cache_interface import (
+    FullAttentionSpec,
+    KVCacheConfig,
+    KVCacheGroupSpec,
+    SlidingWindowSpec,
+)
+from vllm.v1.request import Request
+
+pytestmark = pytest.mark.cpu_test
+
+
+@pytest.fixture(autouse=True)
+def _auto_init_hash_fn():
+    init_none_hash(sha256)
+
+
+def make_request(
+    request_id: str,
+    prompt_token_ids: list[int],
+    block_size: int,
+):
+    return Request(
+        request_id=request_id,
+        prompt_token_ids=prompt_token_ids,
+        mm_features=None,
+        sampling_params=SamplingParams(max_tokens=17),
+        pooling_params=None,
+        eos_token_id=100,
+        lora_request=None,
+        cache_salt=None,
+        block_hasher=get_request_block_hasher(block_size, sha256),
+    )
+
+
+def make_kv_cache_config_multi_sliding_window(
+    block_size: int,
+    num_blocks: int,
+    sliding_windows: list[int],
+) -> KVCacheConfig:
+    """
+    Create a KVCacheConfig with one full attention group and multiple
+    sliding window groups with different window sizes.
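+
+    Group 0 of the returned config is the full attention group; groups
+    1..len(sliding_windows) are sliding window groups, in the order the
+    window sizes are given.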
+ """ + groups = [ + KVCacheGroupSpec( + ["full_attn_layer"], + FullAttentionSpec(block_size, 1, 1, torch.float32), + ) + ] + for i, sw in enumerate(sliding_windows): + groups.append( + KVCacheGroupSpec( + [f"sw_layer_{i}"], + SlidingWindowSpec(block_size, 1, 1, torch.float32, sliding_window=sw), + ) + ) + return KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=groups, + ) + + +class TestHybridKVCacheCoordinatorMultipleSlidingWindows: + def test_verify_and_sort_multiple_sliding_windows(self): + block_size = 16 + kv_cache_config = make_kv_cache_config_multi_sliding_window( + block_size=block_size, + num_blocks=100, + sliding_windows=[256, 512, 128], + ) + + manager = KVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + ) + + coordinator = manager.coordinator + + assert coordinator.full_attention_group_ids == [0] + + # Verify sliding window groups are sorted by window size (descending) + # Original: [1, 2, 3] with windows [256, 512, 128] + # Sorted: [2, 1, 3] with windows [512, 256, 128] + assert coordinator.sliding_window_group_ids == [2, 1, 3] + + window_sizes = [ + spec.sliding_window for spec in coordinator.sliding_window_specs + ] + assert window_sizes == [512, 256, 128], ( + "Specs should be sorted by window size descending" + ) + + def test_cache_hit_multiple_sliding_windows(self): + block_size = 16 + kv_cache_config = make_kv_cache_config_multi_sliding_window( + block_size=block_size, + num_blocks=100, + sliding_windows=[512, 256], + ) + + manager = KVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + ) + common_token_ids = [i for i in range(6) for _ in range(block_size)] + req0 = make_request("0", common_token_ids, block_size) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + assert not computed_blocks.blocks[0] + assert num_computed_tokens == 0 + + blocks = manager.allocate_slots( + req0, + len(common_token_ids), + len(computed_blocks.blocks[0]) * block_size, + computed_blocks, + ) + assert blocks is not None + + manager.free(req0) + req1 = make_request("1", common_token_ids, block_size) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + + assert num_computed_tokens == 80 + assert len(computed_blocks.blocks[0]) == 5 + assert len(computed_blocks.blocks[1]) == 5 + assert len(computed_blocks.blocks[2]) == 5 + + def test_partial_cache_hit_different_sliding_windows(self): + block_size = 16 + kv_cache_config = make_kv_cache_config_multi_sliding_window( + block_size=block_size, + num_blocks=50, + sliding_windows=[64, 32], + ) + + manager = KVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + ) + + common_token_ids = [i for i in range(5) for _ in range(block_size)] + + req0 = make_request("0", common_token_ids, block_size) + computed_blocks, _ = manager.get_computed_blocks(req0) + manager.allocate_slots( + req0, + len(common_token_ids), + len(computed_blocks.blocks[0]) * block_size, + computed_blocks, + ) + + block_hashes = req0.block_hashes + assert len(block_hashes) == 5 + + manager.free(req0) + + req1 = make_request("1", common_token_ids, block_size) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + + assert num_computed_tokens == 64 + assert len(computed_blocks.blocks[0]) == 4 + assert len(computed_blocks.blocks[1]) == 4 + assert 
len(computed_blocks.blocks[2]) == 4
+        manager.free(req1)
+
+        from vllm.v1.core.kv_cache_utils import make_block_hash_with_group_id
+
+        # Evict block[2] from SW-32 (group 2)
+        # group 1 = SW-64, group 2 = SW-32
+        hash_to_evict = make_block_hash_with_group_id(block_hashes[2], 2)
+        manager.block_pool.cached_block_hash_to_block._cache.pop(hash_to_evict, None)
+
+        req2 = make_request("2", common_token_ids, block_size)
+        computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
+
+        assert num_computed_tokens == 32, f"Expected 32, got {num_computed_tokens}"
+        assert len(computed_blocks.blocks[0]) == 2, (
+            "Full attention should have 2 blocks"
+        )
+        assert len(computed_blocks.blocks[1]) == 2, "SW-64 should have 2 blocks"
+        assert len(computed_blocks.blocks[2]) == 2, "SW-32 should have 2 blocks"
diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py
index 1d00873e6062..bb70aca820c3
--- a/vllm/v1/core/kv_cache_coordinator.py
+++ b/vllm/v1/core/kv_cache_coordinator.py
@@ -15,12 +15,14 @@
 from vllm.v1.core.single_type_kv_cache_manager import (
     CrossAttentionManager,
     FullAttentionManager,
+    SlidingWindowManager,
     get_manager_for_kv_cache_spec,
 )
 from vllm.v1.kv_cache_interface import (
     FullAttentionSpec,
     KVCacheConfig,
     KVCacheSpec,
+    SlidingWindowSpec,
 )
 from vllm.v1.request import Request
 
@@ -354,9 +356,8 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
     """
     KV cache coordinator for hybrid models with multiple KV cache types,
     and thus multiple kv cache groups.
-    To simplify `find_longest_cache_hit`, it only supports the combination of
-    two types of KV cache groups, and one of them must be full attention.
-    May extend to more general cases in the future.
+    Supports full attention combined with either a single other attention
+    type or multiple sliding window groups with different window sizes.
     """
 
     def __init__(
@@ -397,14 +398,16 @@ def __init__(
 
     def verify_and_split_kv_cache_groups(self) -> None:
         """
-        Verifies that the model has exactly two types of KV cache groups, and
-        one of them is full attention. Then, split the kv cache groups into full
-        attention groups and other groups.
+        Verifies that the model has full attention groups plus either sliding
+        window groups (possibly different sizes) or one other attention type.
         """
         full_attention_spec: FullAttentionSpec | None = None
         other_spec: KVCacheSpec | None = None
         self.full_attention_group_ids: list[int] = []
         self.other_group_ids: list[int] = []
+        self.sliding_window_group_ids: list[int] = []
+        self.sliding_window_specs: list[SlidingWindowSpec] = []
+
         for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
             if isinstance(g.kv_cache_spec, FullAttentionSpec):
                 if full_attention_spec is None:
@@ -415,6 +418,9 @@ def verify_and_split_kv_cache_groups(self) -> None:
                         "full attention groups now."
                     )
                 self.full_attention_group_ids.append(i)
+            elif isinstance(g.kv_cache_spec, SlidingWindowSpec):
+                self.sliding_window_group_ids.append(i)
+                self.sliding_window_specs.append(g.kv_cache_spec)
             else:
                 if other_spec is None:
                     other_spec = g.kv_cache_spec
@@ -429,39 +435,90 @@ def verify_and_split_kv_cache_groups(self) -> None:
             "HybridKVCacheCoordinator assumes exactly one type of full "
             "attention groups now."
         )
-        assert other_spec is not None, (
-            "HybridKVCacheCoordinator assumes exactly one type of other groups now."
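+        # Exactly one hybrid combination is supported below: full attention
+        # plus sliding window groups (possibly with different window sizes),
+        # or full attention plus a single other attention type, but not both.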
+ has_sliding_window = len(self.sliding_window_group_ids) > 0 + has_other = other_spec is not None + + assert has_sliding_window or has_other, ( + "HybridKVCacheCoordinator requires at least one sliding window group " + "or one other attention group." + ) + assert not (has_sliding_window and has_other), ( + "HybridKVCacheCoordinator does not support mixing sliding window groups " + "with other attention types. Use either sliding windows OR other type." ) - self.full_attention_manager_cls = FullAttentionManager - self.other_attention_cls = self.single_type_managers[ - self.other_group_ids[0] - ].__class__ self.full_attention_spec = full_attention_spec - self.other_spec = other_spec - self.full_attention_block_size = self.full_attention_spec.block_size - self.other_block_size = self.other_spec.block_size - # The LCM of the block sizes of full attention and other attention. - # The cache hit length must be a multiple of the LCM of the block sizes - # to make sure the cache hit length is a multiple of the block size of - # each attention type. Requiring this because we don't support partial - # block cache hit yet. - self.lcm_block_size = lcm(self.full_attention_block_size, self.other_block_size) - - if max(self.full_attention_group_ids) < min(self.other_group_ids): - self.full_attn_first = True - elif max(self.other_group_ids) < min(self.full_attention_group_ids): - self.full_attn_first = False + self.full_attention_block_size = full_attention_spec.block_size + self.full_attention_manager_cls = FullAttentionManager + + if len(self.sliding_window_group_ids) == 0: + self.other_spec = other_spec + self.other_attention_cls = self.single_type_managers[ + self.other_group_ids[0] + ].__class__ + assert self.other_spec is not None + self.other_block_size = self.other_spec.block_size + # The LCM of the block sizes of full attention and other attention. + # The cache hit length must be a multiple of the LCM of the block sizes + # to make sure the cache hit length is a multiple of the block size of + # each attention type. Requiring this because we don't support partial + # block cache hit yet. + self.lcm_block_size = lcm( + self.full_attention_block_size, self.other_block_size + ) + + if max(self.full_attention_group_ids) < min(self.other_group_ids): + self.full_attn_first = True + elif max(self.other_group_ids) < min(self.full_attention_group_ids): + self.full_attn_first = False + else: + raise ValueError( + "HybridKVCacheCoordinator assumes the full " + "attention group ids and other attention group ids " + "do not interleave, either full attention group ids " + "are before other attention group ids or vice versa." + "This is for simplifying merging hit_blocks_full_attn and " + "hit_blocks_other_attn to hit_blocks." + ) else: - raise ValueError( - "HybridKVCacheCoordinator assumes the full " - "attention group ids and other attention group ids " - "do not interleave, either full attention group ids " - "are before other attention group ids or vice versa." - "This is for simplifying merging hit_blocks_full_attn and " - "hit_blocks_other_attn to hit_blocks." + self.full_attention_manager_cls = FullAttentionManager + self.other_attention_cls = SlidingWindowManager + self._sort_other_attention_specs() + + for other_spec in self.sliding_window_specs: + assert other_spec.block_size % self.full_attention_block_size == 0, ( + "Sliding window attention block_size must be divisible by " + "full attention block_size." 
+ ) + all_block_sizes = [self.full_attention_block_size] + [ + spec.block_size for spec in self.sliding_window_specs + ] + self.lcm_block_size = all_block_sizes[0] + for bs in all_block_sizes[1:]: + self.lcm_block_size = lcm(self.lcm_block_size, bs) + + self.max_group_id = max( + max(self.full_attention_group_ids), max(self.sliding_window_group_ids) ) + def _sort_other_attention_specs(self) -> None: + """ + Sort other attention specs by sliding window size (largest first). + This ensures we process larger windows first when finding cache hits. + """ + paired_data = [] + for group_id, spec in zip( + self.sliding_window_group_ids, + self.sliding_window_specs, + ): + sort_key = spec.sliding_window + paired_data.append((group_id, spec, sort_key)) + + paired_data.sort(key=lambda x: x[2], reverse=True) + + self.sliding_window_group_ids = [pair[0] for pair in paired_data] + self.sliding_window_specs = [pair[1] for pair in paired_data] + def find_longest_cache_hit( self, block_hashes: list[BlockHash], @@ -501,9 +558,26 @@ def find_longest_cache_hit( alignment_tokens=self.lcm_block_size, ) hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size + if len(self.sliding_window_group_ids) == 0: + # Mode 1: Original logic for single other type + return self._find_cache_hit_single_other_type( + block_hashes, hit_blocks_full_attn, hit_length + ) + else: + # Mode 2: Multiple sliding windows + return self._find_cache_hit_multiple_sliding_windows( + block_hashes, hit_blocks_full_attn, hit_length + ) + def _find_cache_hit_single_other_type( + self, + block_hashes: list[BlockHash], + hit_blocks_full_attn: tuple[list[KVCacheBlock], ...], + hit_length: int, + ) -> tuple[tuple[list[KVCacheBlock], ...], int]: # Next, find the cache hit for the other attention WITHIN # the cache hit of full attention. + assert self.other_spec is not None if self.other_spec.block_size == self.hash_block_size: # Common case. 
other_block_hashes: BlockHashList = block_hashes @@ -546,6 +620,71 @@ def find_longest_cache_hit( hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn return hit_blocks, hit_length + def _find_cache_hit_multiple_sliding_windows( + self, + block_hashes: list[BlockHash], + hit_blocks_full_attn: tuple[list[KVCacheBlock], ...], + hit_length: int, + ) -> tuple[tuple[list[KVCacheBlock], ...], int]: + """Mode 2: Full + multiple sliding windows.""" + hit_blocks_other_attn: list[list[KVCacheBlock]] = [] + min_hit_length = hit_length + + for i, other_spec in enumerate(self.sliding_window_specs): + other_block_size = other_spec.block_size + + if other_block_size == self.hash_block_size: + other_block_hashes: BlockHashList = block_hashes + else: + other_block_hashes = BlockHashListWithBlockSize( + block_hashes, self.hash_block_size, other_block_size + ) + + hit_blocks_single = self.other_attention_cls.find_longest_cache_hit( + block_hashes=other_block_hashes, + max_length=min_hit_length, + kv_cache_group_ids=[self.sliding_window_group_ids[i]], + block_pool=self.block_pool, + kv_cache_spec=other_spec, + use_eagle=self.use_eagle, + alignment_tokens=self.lcm_block_size, + ) + + current_hit_length = len(hit_blocks_single[0]) * other_block_size + min_hit_length = min(min_hit_length, current_hit_length) + + assert min_hit_length % self.full_attention_block_size == 0, ( + f"Cache hit length {min_hit_length} not aligned to " + f"full attention block size {self.full_attention_block_size}" + ) + + hit_blocks_other_attn.append(hit_blocks_single[0]) + + for group_hit_blocks in hit_blocks_full_attn: + num_blocks_to_keep = min_hit_length // self.full_attention_block_size + del group_hit_blocks[num_blocks_to_keep:] + + for i, other_spec in enumerate(self.sliding_window_specs): + other_block_size = other_spec.block_size + num_blocks_to_keep = min_hit_length // other_block_size + del hit_blocks_other_attn[i][num_blocks_to_keep:] + + hit_blocks: list[list[KVCacheBlock]] = [ + [] for _ in range(self.max_group_id + 1) + ] + + for group_id, blocks in zip( + self.full_attention_group_ids, hit_blocks_full_attn + ): + hit_blocks[group_id] = blocks + + for group_id, blocks in zip( + self.sliding_window_group_ids, hit_blocks_other_attn + ): + hit_blocks[group_id] = blocks + + return tuple(hit_blocks), min_hit_length + def get_kv_cache_coordinator( kv_cache_config: KVCacheConfig,