diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 0600b813a489..486e5f9cd4c8 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -35,6 +35,7 @@ FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, + MambaSpec, SlidingWindowSpec, ) @@ -106,8 +107,23 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: def make_kv_cache_config_hybrid_model( - block_size: int, num_blocks: int + block_size: int, num_blocks: int, second_spec_type: str = "sliding_window" ) -> KVCacheConfig: + if second_spec_type == "sliding_window": + second_spec = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=2 * block_size, + ) + elif second_spec_type == "mamba": + second_spec = MambaSpec( + block_size=block_size, + shapes=(1, 1), + dtypes=(torch.float32,), + ) + return KVCacheConfig( num_blocks=num_blocks, kv_cache_tensors=[], @@ -123,16 +139,49 @@ def make_kv_cache_config_hybrid_model( ), KVCacheGroupSpec( ["layer2"], - SlidingWindowSpec( + second_spec, + ), + KVCacheGroupSpec( + ["layer3"], + second_spec, + ), + ], + ) + + +def make_kv_cache_config_three_types( + block_size: int, num_blocks: int, third_spec_type: str = "mamba" +) -> KVCacheConfig: + if third_spec_type == "mamba": + third_spec = MambaSpec( + block_size=block_size, + shapes=(1, 1), + dtypes=(torch.float32,), + ) + elif third_spec_type == "sliding_window": + third_spec = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=4 * block_size, + ) + + return KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer1"], + FullAttentionSpec( block_size=block_size, num_kv_heads=1, head_size=1, dtype=torch.float32, - sliding_window=2 * block_size, ), ), KVCacheGroupSpec( - ["layer3"], + ["layer2"], SlidingWindowSpec( block_size=block_size, 
num_kv_heads=1, @@ -141,6 +190,10 @@ def make_kv_cache_config_hybrid_model( sliding_window=2 * block_size, ), ), + KVCacheGroupSpec( + ["layer3"], + third_spec, + ), ], ) @@ -424,6 +477,184 @@ def test_partial_request_hit( ) +def _make_hybrid_kv_cache_config( + block_size: int, num_blocks: int, spec_types: list[str] +) -> KVCacheConfig: + """ + Create a KVCacheConfig with the specified spec types. + + Args: + block_size: The block size for KV cache. + num_blocks: The number of blocks in the KV cache. + spec_types: List of spec type strings. Supported types: + - "full": FullAttentionSpec + - "sliding_window": SlidingWindowSpec with window=2*block_size + - "sliding_window_large": SlidingWindowSpec with window=4*block_size + - "mamba": MambaSpec + """ + spec_map = { + "full": lambda: FullAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ), + "sliding_window": lambda: SlidingWindowSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=2 * block_size, + ), + "sliding_window_large": lambda: SlidingWindowSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=4 * block_size, + ), + "mamba": lambda: MambaSpec( + block_size=block_size, + shapes=(1, 1), + dtypes=(torch.float32,), + ), + } + + kv_cache_groups = [ + KVCacheGroupSpec([f"layer{i}"], spec_map[spec_type]()) + for i, spec_type in enumerate(spec_types) + ] + + return KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=kv_cache_groups, + ) + + +# Test cases covering various combinations of KV cache spec types: +# - Varying number of groups (2, 3, or 4) +# - 0, 1, or 2 full attention groups +# - Sliding window with different window sizes +# - Interleaved group IDs (full attn and other types mixed) +# - Mamba spec combinations +_HYBRID_MODEL_TEST_CASES = [ + # 2 groups: 1 full + 1 other + pytest.param(["full", "sliding_window"], id="2g-full+sw"), + 
pytest.param(["full", "mamba"], id="2g-full+mamba"), +    # 2 groups: 0 full (all other types) +    pytest.param(["sliding_window", "mamba"], id="2g-sw+mamba"), +    pytest.param(["sliding_window", "sliding_window_large"], id="2g-sw+sw_large"), +    # 3 groups: 1 full + 2 others (same type) +    pytest.param(["full", "sliding_window", "sliding_window"], id="3g-full+2sw"), +    pytest.param(["full", "mamba", "mamba"], id="3g-full+2mamba"), +    # 3 groups: 1 full + 2 others (different types) +    pytest.param(["full", "sliding_window", "mamba"], id="3g-full+sw+mamba"), +    pytest.param( +        ["full", "sliding_window", "sliding_window_large"], +        id="3g-full+sw+sw_large", +    ), +    # 3 groups: 2 full + 1 other +    pytest.param(["full", "full", "sliding_window"], id="3g-2full+sw"), +    pytest.param(["full", "full", "mamba"], id="3g-2full+mamba"), +    # 4 groups: interleaved (full, other, full, other) +    pytest.param( +        ["full", "sliding_window", "full", "sliding_window_large"], +        id="4g-interleaved-full+sw+sw_large", +    ), +    pytest.param( +        ["full", "mamba", "full", "mamba"], +        id="4g-interleaved-full+mamba", +    ), +    # 4 groups: interleaved with the sliding window sizes swapped +    pytest.param( +        ["full", "sliding_window_large", "full", "sliding_window"], +        id="4g-interleaved-full+sw_mixed", +    ), +    # 4 groups: 0 full (all other types) +    pytest.param( +        ["sliding_window", "mamba", "sliding_window_large", "mamba"], +        id="4g-sw+mamba+sw_large+mamba", +    ), +    # 4 groups: 2 full + 2 others (grouped) +    pytest.param( +        ["full", "full", "sliding_window", "mamba"], +        id="4g-2full+sw+mamba", +    ), +] + + +@pytest.mark.parametrize("spec_types", _HYBRID_MODEL_TEST_CASES) +def test_prefill_hybrid_model_combinations(spec_types: list[str]): +    """ +    Test prefix caching with hybrid models containing various combinations of +    KV cache spec types.
+ + This unified test covers: + - Various combinations (full attn + other attn types) + - Varying number of groups (2, 3, or 4) + - 0, 1, or 2 full attention groups in the combination + - Two sliding_window attn groups with different window sizes + - Interleaved group IDs (full attn and other types alternating) + - Mamba spec with other attention types + """ + block_size = 16 + num_groups = len(spec_types) + # Allocate enough blocks for all groups + num_blocks = 10 * num_groups + + kv_cache_config = _make_hybrid_kv_cache_config(block_size, num_blocks, spec_types) + manager = KVCacheManager( + kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + ) + + hash_fn = sha256 + + # Complete 3 blocks (48 tokens) + common_token_ids = [i for i in range(3) for _ in range(block_size)] + unique_token_ids = [3] * 7 + all_token_ids = common_token_ids + unique_token_ids + + # First request: no cache hit initially + req0 = make_request("0", all_token_ids, block_size, hash_fn) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + + assert len(req0.block_hashes) == 3 + assert not computed_blocks.blocks[0] # No cache hit initially + assert num_computed_tokens == 0 + + blocks = manager.allocate_slots( + req0, 55, len(computed_blocks.blocks[0]) * block_size, computed_blocks + ) + assert blocks is not None + # Should have blocks for all groups + assert len(blocks.get_block_ids()) == num_groups + + # Second request: should hit cached blocks for common prefix + req1 = make_request("1", common_token_ids + [4] * 5, block_size, hash_fn) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + + # Should hit cached blocks for all groups + assert num_computed_tokens == 3 * block_size + assert len(computed_blocks.blocks) == num_groups + + # Allocate and verify blocks for second request + blocks = manager.allocate_slots( + req1, + len(common_token_ids) + 5 - num_computed_tokens, + num_computed_tokens, + 
computed_blocks, + ) + assert blocks is not None + assert len(blocks.get_block_ids()) == num_groups + + manager.free(req0) + manager.free(req1) + + def test_prefill_plp(): """Test prefill with APC and some prompt logprobs (plp) requests. diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 1d00873e6062..4550e2b79562 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -14,7 +14,7 @@ ) from vllm.v1.core.single_type_kv_cache_manager import ( CrossAttentionManager, - FullAttentionManager, + SingleTypeKVCacheManager, get_manager_for_kv_cache_spec, ) from vllm.v1.kv_cache_interface import ( @@ -354,9 +354,6 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): """ KV cache coordinator for hybrid models with multiple KV cache types, and thus multiple kv cache groups. - To simplify `find_longest_cache_hit`, it only supports the combination of - two types of KV cache groups, and one of them must be full attention. - May extend to more general cases in the future. """ def __init__( @@ -397,70 +394,46 @@ def __init__( def verify_and_split_kv_cache_groups(self) -> None: """ - Verifies that the model has exactly two types of KV cache groups, and - one of them is full attention. Then, split the kv cache groups into full - attention groups and other groups. + Groups KV cache groups by their spec type for efficient batch processing + during cache hit lookup. 
""" - full_attention_spec: FullAttentionSpec | None = None - other_spec: KVCacheSpec | None = None - self.full_attention_group_ids: list[int] = [] - self.other_group_ids: list[int] = [] + attention_groups: list[ + tuple[KVCacheSpec, list[int], type[SingleTypeKVCacheManager]] + ] = [] + for i, g in enumerate(self.kv_cache_config.kv_cache_groups): - if isinstance(g.kv_cache_spec, FullAttentionSpec): - if full_attention_spec is None: - full_attention_spec = g.kv_cache_spec - else: - assert full_attention_spec == g.kv_cache_spec, ( - "HybridKVCacheCoordinator assumes exactly one type of " - "full attention groups now." + manager_cls = self.single_type_managers[i].__class__ + spec = g.kv_cache_spec + + # Try to find an existing group with the same spec + for existing_spec, group_ids, existing_cls in attention_groups: + if existing_spec == spec: + assert manager_cls is existing_cls, ( + "Expected same manager class for identical KV cache specs." ) - self.full_attention_group_ids.append(i) + group_ids.append(i) + break else: - if other_spec is None: - other_spec = g.kv_cache_spec - else: - assert other_spec == g.kv_cache_spec, ( - "HybridKVCacheCoordinator assumes " - "exactly one other type of groups now." - ) - self.other_group_ids.append(i) + attention_groups.append((spec, [i], manager_cls)) - assert full_attention_spec is not None, ( - "HybridKVCacheCoordinator assumes exactly one type of full " - "attention groups now." + assert len(attention_groups) > 1, ( + "HybridKVCacheCoordinator requires at least two attention groups." ) - assert other_spec is not None, ( - "HybridKVCacheCoordinator assumes exactly one type of other groups now." + + # Put full attention first: its efficient left-to-right scan provides + # a tighter initial bound, reducing work for subsequent groups. 
+ self.attention_groups = sorted( + attention_groups, + key=lambda x: not isinstance(x[0], FullAttentionSpec), ) - self.full_attention_manager_cls = FullAttentionManager - self.other_attention_cls = self.single_type_managers[ - self.other_group_ids[0] - ].__class__ - self.full_attention_spec = full_attention_spec - self.other_spec = other_spec - self.full_attention_block_size = self.full_attention_spec.block_size - self.other_block_size = self.other_spec.block_size - # The LCM of the block sizes of full attention and other attention. + # The LCM of the block sizes of all attention types. # The cache hit length must be a multiple of the LCM of the block sizes # to make sure the cache hit length is a multiple of the block size of # each attention type. Requiring this because we don't support partial # block cache hit yet. - self.lcm_block_size = lcm(self.full_attention_block_size, self.other_block_size) - - if max(self.full_attention_group_ids) < min(self.other_group_ids): - self.full_attn_first = True - elif max(self.other_group_ids) < min(self.full_attention_group_ids): - self.full_attn_first = False - else: - raise ValueError( - "HybridKVCacheCoordinator assumes the full " - "attention group ids and other attention group ids " - "do not interleave, either full attention group ids " - "are before other attention group ids or vice versa." - "This is for simplifying merging hit_blocks_full_attn and " - "hit_blocks_other_attn to hit_blocks." - ) + block_sizes = [spec.block_size for spec, _, _ in attention_groups] + self.lcm_block_size = lcm(*block_sizes) def find_longest_cache_hit( self, @@ -468,7 +441,12 @@ def find_longest_cache_hit( max_cache_hit_length: int, ) -> tuple[tuple[list[KVCacheBlock], ...], int]: """ - Find the longest cache hit for the request. + Find the longest cache hit using an iterative fixed-point algorithm. + + Each attention type either accepts the current candidate length or + reduces it. 
If any type reduces the length, restart checks over all + types. This converges because length monotonically decreases and is + bounded below by 0. Args: block_hashes: The block hashes of the request. @@ -476,75 +454,63 @@ def find_longest_cache_hit( Returns: A tuple containing: - - A list of the cache hit blocks for each single type manager. + - A tuple of the cache hit blocks for each single type manager. - The number of tokens of the longest cache hit. """ - # First, find the longest cache hit for full attention. - if self.full_attention_spec.block_size == self.hash_block_size: - # Common case. - full_attention_block_hashes: BlockHashList = block_hashes - else: - # block_size is a multiple of hash_block_size. This happens when different - # KV cache groups have different block sizes. In this case, we need to - # recalculate block_hashes at the granularity of block_size, using the - # original block_hashes (at the granularity of hash_block_size). - full_attention_block_hashes = BlockHashListWithBlockSize( - block_hashes, self.hash_block_size, self.full_attention_spec.block_size - ) - hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit( - block_hashes=full_attention_block_hashes, - max_length=max_cache_hit_length, - kv_cache_group_ids=self.full_attention_group_ids, - block_pool=self.block_pool, - kv_cache_spec=self.full_attention_spec, - use_eagle=self.use_eagle, - alignment_tokens=self.lcm_block_size, - ) - hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size - - # Next, find the cache hit for the other attention WITHIN - # the cache hit of full attention. - if self.other_spec.block_size == self.hash_block_size: - # Common case. - other_block_hashes: BlockHashList = block_hashes - else: - # Similar to the full attention case, here we need to recalculate - # block_hashes at the granularity of block_size, using the original - # block_hashes (at the granularity of hash_block_size). 
- other_block_hashes = BlockHashListWithBlockSize( - block_hashes, self.hash_block_size, self.other_spec.block_size + + def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: + if kv_cache_spec.block_size == self.hash_block_size: + return block_hashes + return BlockHashListWithBlockSize( + block_hashes, self.hash_block_size, kv_cache_spec.block_size ) - hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit( - block_hashes=other_block_hashes, - max_length=hit_length, - kv_cache_group_ids=self.other_group_ids, - block_pool=self.block_pool, - kv_cache_spec=self.other_spec, - use_eagle=self.use_eagle, - alignment_tokens=self.lcm_block_size, - ) - hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size - - # NOTE: the prefix cache hit length must be a multiple of block_size as - # we don't support partial block cache hit yet. The cache hit length - # of other attention is ensured to be a multiple of the block size of - # full attention layers in current implementation, because hit_length is - # a multiple of other attention's block size, and other attention's - # block size is a multiple of full attention's block size (verified in - # `verify_and_split_kv_cache_groups`). - assert hit_length % self.full_attention_block_size == 0 - - # Truncate the full attention cache hit to the length of the - # cache hit of the other attention. - for group_hit_blocks in hit_blocks_full_attn: - del group_hit_blocks[hit_length // self.full_attention_block_size :] - - # Merge the hit blocks of full attention and other attention. 
- if self.full_attn_first: - hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn - else: - hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn - return hit_blocks, hit_length + + num_groups = len(self.kv_cache_config.kv_cache_groups) + hit_length = max_cache_hit_length + hit_blocks_by_group: list[list[KVCacheBlock] | None] = [None] * num_groups + + while True: + curr_hit_length = hit_length + + for spec, group_ids, manager_cls in self.attention_groups: + is_full_attn = isinstance(spec, FullAttentionSpec) + + # Full attention: reuse cached blocks (downward-closed property) + cached_blocks = hit_blocks_by_group[group_ids[0]] + if is_full_attn and cached_blocks is not None: + # For full attention, we only need to compute the cache hit + # length once. Starting from the second iteration, if the + # curr_hit_length is reduced by other groups, we can simply + # keep the first (curr_hit_length // block_size) blocks from + # the last iteration. + num_blocks = curr_hit_length // spec.block_size + curr_hit_length = num_blocks * spec.block_size + for group_id in group_ids: + blocks = hit_blocks_by_group[group_id] + assert blocks is not None + del blocks[num_blocks:] + else: + hit_blocks = manager_cls.find_longest_cache_hit( + block_hashes=_get_block_hashes(spec), + max_length=curr_hit_length, + kv_cache_group_ids=group_ids, + block_pool=self.block_pool, + kv_cache_spec=spec, + use_eagle=self.use_eagle, + alignment_tokens=self.lcm_block_size, + ) + curr_hit_length = len(hit_blocks[0]) * spec.block_size + for group_id, blocks in zip(group_ids, hit_blocks): + hit_blocks_by_group[group_id] = blocks + + if curr_hit_length < hit_length: + hit_length = curr_hit_length + else: + break + + return tuple( + blocks if blocks is not None else [] for blocks in hit_blocks_by_group + ), hit_length def get_kv_cache_coordinator(