diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 486e5f9cd4c8..1d951cfdac40 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -107,7 +107,10 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: def make_kv_cache_config_hybrid_model( - block_size: int, num_blocks: int, second_spec_type: str = "sliding_window" + block_size: int, + num_blocks: int, + sliding_window_blocks: int, + second_spec_type: str = "sliding_window", ) -> KVCacheConfig: if second_spec_type == "sliding_window": second_spec = SlidingWindowSpec( @@ -115,7 +118,7 @@ def make_kv_cache_config_hybrid_model( num_kv_heads=1, head_size=1, dtype=torch.float32, - sliding_window=2 * block_size, + sliding_window=sliding_window_blocks * block_size, ) elif second_spec_type == "mamba": second_spec = MambaSpec( @@ -325,7 +328,7 @@ def test_prefill(hash_fn): def test_prefill_hybrid_model(): block_size = 16 manager = KVCacheManager( - make_kv_cache_config_hybrid_model(block_size, 21), + make_kv_cache_config_hybrid_model(block_size, 21, 2), max_model_len=8192, enable_caching=True, hash_block_size=block_size, @@ -334,7 +337,8 @@ def test_prefill_hybrid_model(): hash_fn = sha256 # Complete 3 blocks (48 tokens) - common_token_ids = [i for i in range(3) for _ in range(block_size)] + num_full_blocks = 3 + common_token_ids = [i for i in range(num_full_blocks) for _ in range(block_size)] # Fully cache miss # Incomplete 1 block (7 tokens) @@ -375,6 +379,7 @@ def test_prefill_hybrid_model(): # Cache hit in the common prefix # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 + all_token_ids = common_token_ids + unique_token_ids req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 @@ -394,34 +399,13 @@ def test_prefill_hybrid_model(): manager.free(req0) 
manager.free(req1) - cached_block_hash_to_block_bak = copy.copy( - manager.block_pool.cached_block_hash_to_block._cache - ) - - def test_partial_request_hit( - request_id: str, - hash_to_evict: list[BlockHashWithGroupId], - expect_hit_length: int, - ): - req = make_request( - request_id, common_token_ids + unique_token_ids, block_size, sha256 - ) - for hash_with_group_id in hash_to_evict: - manager.block_pool.cached_block_hash_to_block._cache.pop(hash_with_group_id) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert len(req.block_hashes) == 3 - assert num_computed_tokens == expect_hit_length * block_size - for block_per_group in computed_blocks.blocks: - assert len(block_per_group) == num_computed_tokens // block_size - for hash_with_group_id in hash_to_evict: - manager.block_pool.cached_block_hash_to_block._cache[hash_with_group_id] = ( - cached_block_hash_to_block_bak[hash_with_group_id] - ) - manager.free(req) - # Evict the blocks outside sliding window, does not affect the hit length. - test_partial_request_hit( + _test_partial_request_hit( + manager, + block_size, + num_full_blocks, "2", + all_token_ids, [ make_block_hash_with_group_id(block_hashes[0], 1), make_block_hash_with_group_id(block_hashes[0], 2), @@ -430,13 +414,23 @@ def test_partial_request_hit( ) # Evict the first block of full attention, makes total cache miss. - test_partial_request_hit( - "3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0 + _test_partial_request_hit( + manager, + block_size, + num_full_blocks, + "3", + all_token_ids, + [make_block_hash_with_group_id(block_hashes[0], 0)], + 0, ) # Evict the last block of all layers, reduces the hit length to 2. 
- test_partial_request_hit( + _test_partial_request_hit( + manager, + block_size, + num_full_blocks, "4", + all_token_ids, [ make_block_hash_with_group_id(block_hashes[2], 0), make_block_hash_with_group_id(block_hashes[2], 1), @@ -446,18 +440,36 @@ def test_partial_request_hit( ) # Evict the last block of full attention, reduces the hit length to 2. - test_partial_request_hit( - "5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2 + _test_partial_request_hit( + manager, + block_size, + num_full_blocks, + "5", + all_token_ids, + [make_block_hash_with_group_id(block_hashes[2], 0)], + 2, ) # Evict the last block of sliding window, reduces the hit length to 2. - test_partial_request_hit( - "6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2 + _test_partial_request_hit( + manager, + block_size, + num_full_blocks, + "6", + all_token_ids, + [make_block_hash_with_group_id(block_hashes[2], 1)], + 2, ) # Evict the last block of sliding window, reduces the hit length to 2. - test_partial_request_hit( - "7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2 + _test_partial_request_hit( + manager, + block_size, + num_full_blocks, + "7", + all_token_ids, + [make_block_hash_with_group_id(block_hashes[2], 2)], + 2, ) # Evict different set of blocks for full attention and sliding window makes @@ -466,8 +478,12 @@ def test_partial_request_hit( # The cache hit length of sliding window is 2 * block_size. # Then it is cache miss as the two type of layers # have different hit length. 
- test_partial_request_hit( + _test_partial_request_hit( + manager, + block_size, + num_full_blocks, "8", + all_token_ids, [ make_block_hash_with_group_id(block_hashes[2], 0), make_block_hash_with_group_id(block_hashes[0], 1), @@ -477,6 +493,214 @@ def test_partial_request_hit( ) +def test_prefill_hybrid_model_eagle(): + block_size = 16 + kv_cache_config = make_kv_cache_config_hybrid_model(block_size, 31, 3) + manager = KVCacheManager( + kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + use_eagle=True, + ) + + hash_fn = sha256 + + # Complete 6 blocks (96 tokens) + num_full_blocks = 6 + common_token_ids = [i for i in range(num_full_blocks) for _ in range(block_size)] + + # Fully cache miss + # Incomplete 1 block (7 tokens) + unique_token_ids = [6] * 7 + all_token_ids = common_token_ids + unique_token_ids + req0 = make_request("0", all_token_ids, block_size, hash_fn) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + assert len(req0.block_hashes) == len(all_token_ids) // block_size + assert not computed_blocks.blocks[0] + assert num_computed_tokens == 0 + blocks = manager.allocate_slots( + req0, len(all_token_ids), num_computed_tokens, computed_blocks + ) + block_ids = ( + [1, 2, 3, 4, 5, 6, 7], + [8, 9, 10, 11, 12, 13, 14], + [15, 16, 17, 18, 19, 20, 21], + ) + assert blocks is not None and blocks.get_block_ids() == block_ids + + # Check full block metadata + parent_block_hash = None + for i, full_block_ids in enumerate(zip(*(row[:-1] for row in block_ids))): + block_tokens = tuple(all_token_ids[i * block_size : (i + 1) * block_size]) + block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) + for group_id, block_id in enumerate(full_block_ids): + blk_hash = manager.block_pool.blocks[block_id].block_hash + assert blk_hash is not None + assert get_block_hash(blk_hash) == block_hash + assert get_group_id(blk_hash) == group_id + assert manager.block_pool.blocks[block_id].ref_cnt == 1 
+ parent_block_hash = block_hash + + # Check partial block metadata + for partial_block_id in (row[-1] for row in block_ids): + assert manager.block_pool.blocks[partial_block_id].block_hash is None + assert manager.block_pool.blocks[partial_block_id].ref_cnt == 1 + + # Cache hit in the common prefix + # Incomplete 1 block (5 tokens) + unique_token_ids = [6] * 5 + all_token_ids = common_token_ids + unique_token_ids + req1 = make_request("1", all_token_ids, block_size, hash_fn) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + assert len(req1.block_hashes) == num_full_blocks + assert computed_blocks.get_block_ids() == ( + [1, 2, 3, 4], + [0, 9, 10, 11], + [0, 16, 17, 18], + ) + assert num_computed_tokens == 4 * block_size + num_new_tokens = len(all_token_ids) - num_computed_tokens + blocks = manager.allocate_slots( + req1, num_new_tokens, num_computed_tokens, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ( + [22, 23, 24], + [25, 26, 27], + [28, 29, 30], + ) + for block_per_group in computed_blocks.blocks: + for block in block_per_group: + if block != manager.block_pool.null_block: + assert block.ref_cnt == 2 + + block_hashes = req1.block_hashes + manager.free(req0) + manager.free(req1) + + # Evict the blocks outside sliding window, does not affect the hit length. + _test_partial_request_hit( + manager, + block_size, + num_full_blocks, + "2", + all_token_ids, + [ + make_block_hash_with_group_id(block_hashes[0], 1), + make_block_hash_with_group_id(block_hashes[0], 2), + ], + 4, + ) + + # Evict the first block of full attention, makes total cache miss. + _test_partial_request_hit( + manager, + block_size, + num_full_blocks, + "3", + all_token_ids, + [make_block_hash_with_group_id(block_hashes[0], 0)], + 0, + ) + + # Evict the last block of all layers, reduces the hit length to 3. 
+ _test_partial_request_hit( + manager, + block_size, + num_full_blocks, + "4", + all_token_ids, + [ + make_block_hash_with_group_id(block_hashes[-1], 0), + make_block_hash_with_group_id(block_hashes[-1], 1), + make_block_hash_with_group_id(block_hashes[-1], 2), + ], + 3, + ) + + # Evict the last block of full attention, reduces the hit length to 3. + _test_partial_request_hit( + manager, + block_size, + num_full_blocks, + "5", + all_token_ids, + [make_block_hash_with_group_id(block_hashes[-1], 0)], + 3, + ) + + # Since the last block of full attention is dropped for eagle, evict + # the second last block of sliding window, reduces the hit length to 3. + _test_partial_request_hit( + manager, + block_size, + num_full_blocks, + "6", + all_token_ids, + [make_block_hash_with_group_id(block_hashes[-2], 1)], + 3, + ) + + # Since the last block of full attention is dropped for eagle, evict + # the second last block of sliding window, reduces the hit length to 3. + _test_partial_request_hit( + manager, + block_size, + num_full_blocks, + "7", + all_token_ids, + [make_block_hash_with_group_id(block_hashes[-2], 2)], + 3, + ) + + # Evict different set of blocks for full attention and sliding window makes + # total cache miss. + # The cache hit length of full attention is 4 * block_size. + # The cache hit length of sliding window is 3 * block_size. + # Then it is cache miss as the two type of layers + # have different hit length. 
+ _test_partial_request_hit( + manager, + block_size, + num_full_blocks, + "8", + all_token_ids, + [ + make_block_hash_with_group_id(block_hashes[-1], 0), + make_block_hash_with_group_id(block_hashes[0], 1), + make_block_hash_with_group_id(block_hashes[0], 2), + ], + 0, + ) + + +def _test_partial_request_hit( + manager: KVCacheManager, + block_size: int, + num_full_blocks, + request_id: str, + prompt_token_ids: list[int], + hash_to_evict: list[BlockHashWithGroupId], + expect_hit_length: int, +): + cached_block_hash_to_block_bak = copy.copy( + manager.block_pool.cached_block_hash_to_block._cache + ) + req = make_request(request_id, prompt_token_ids, block_size, sha256) + for hash_with_group_id in hash_to_evict: + manager.block_pool.cached_block_hash_to_block._cache.pop(hash_with_group_id) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) + assert len(req.block_hashes) == num_full_blocks + assert num_computed_tokens == expect_hit_length * block_size + for block_per_group in computed_blocks.blocks: + assert len(block_per_group) == num_computed_tokens // block_size + for hash_with_group_id in hash_to_evict: + manager.block_pool.cached_block_hash_to_block._cache[hash_with_group_id] = ( + cached_block_hash_to_block_bak[hash_with_group_id] + ) + manager.free(req) + + def _make_hybrid_kv_cache_config( block_size: int, num_blocks: int, spec_types: list[str] ) -> KVCacheConfig: @@ -655,6 +879,140 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]): manager.free(req1) +# Test cases covering various combinations of KV cache spec types: +# - Varying number of groups (2, 3, or 4) +# - 0, 1, or 2 full attention groups +# - Sliding window with different window sizes +# - Interleaved group IDs (full attn and other types mixed) +# - Mamba spec combinations +_EAGLE_HYBRID_MODEL_TEST_CASES = [ + # 2 groups: 1 full + 1 other + pytest.param(["full", "sliding_window"], 2, id="2g-full+sw"), + pytest.param(["full", "mamba"], 3, id="2g-full+mamba"), 
+ # 2 groups: 0 full (all other types) + pytest.param(["sliding_window", "mamba"], 3, id="2g-sw+mamba"), + pytest.param(["sliding_window", "sliding_window_large"], 2, id="2g-sw+sw_large"), + # 3 groups: 1 full + 2 others (same type) + pytest.param(["full", "sliding_window", "sliding_window"], 2, id="3g-full+2sw"), + pytest.param(["full", "mamba", "mamba"], 3, id="3g-full+2mamba"), + # 3 groups: 1 full + 2 others (different types) + pytest.param(["full", "sliding_window", "mamba"], 2, id="3g-full+sw+mamba"), + pytest.param( + ["full", "sliding_window", "sliding_window_large"], + 1, + id="3g-full+sw+sw_large", + ), + # 3 groups: 2 full + 1 other + pytest.param(["full", "full", "sliding_window"], 2, id="3g-2full+sw"), + pytest.param(["full", "full", "mamba"], 3, id="3g-2full+mamba"), + # 4 groups: interleaved (full, other, full, other) + pytest.param( + ["full", "sliding_window", "full", "sliding_window_large"], + 1, + id="4g-interleaved-full+sw+sw_large", + ), + pytest.param( + ["full", "mamba", "full", "mamba"], + 3, + id="4g-interleaved-full+mamba", + ), + # 4 groups: interleaved with different sliding windows + pytest.param( + ["full", "sliding_window", "full", "sliding_window_large"], + 1, + id="4g-interleaved-full+sw_mixed", + ), + # 4 groups: 0 full (all other types) + pytest.param( + ["sliding_window", "mamba", "sliding_window_large", "mamba"], + 2, + id="4g-sw+mamba+sw_large+mamba", + ), + # 4 groups: 2 full + 2 others (grouped) + pytest.param( + ["full", "full", "sliding_window", "mamba"], + 2, + id="4g-2full+sw+mamba", + ), +] + + +@pytest.mark.parametrize("spec_types,expect_hit_length", _EAGLE_HYBRID_MODEL_TEST_CASES) +def test_prefill_hybrid_model_combinations_eagle( + spec_types: list[str], expect_hit_length: int +): + """ + Test prefix caching with hybrid models containing various combinations of + KV cache spec types. 
+
+    This unified test covers:
+    - Various combinations (full attn + other attn types)
+    - Varying number of groups (2, 3, or 4)
+    - 0, 1, or 2 full attention groups in the combination
+    - Two sliding_window attn groups with different window sizes
+    - Interleaved group IDs (full attn and other types alternating)
+    - Mamba spec with other attention types
+    """
+    block_size = 16
+    num_groups = len(spec_types)
+    # Allocate enough blocks for all groups
+    num_blocks = 10 * num_groups
+
+    kv_cache_config = _make_hybrid_kv_cache_config(block_size, num_blocks, spec_types)
+    manager = KVCacheManager(
+        kv_cache_config,
+        max_model_len=8192,
+        enable_caching=True,
+        hash_block_size=block_size,
+        use_eagle=True,
+    )
+
+    hash_fn = sha256
+
+    # Complete 4 blocks (64 tokens)
+    num_full_blocks = 4
+    common_token_ids = [i for i in range(num_full_blocks) for _ in range(block_size)]
+    unique_token_ids = [4] * 7
+    all_token_ids = common_token_ids + unique_token_ids
+
+    # First request: no cache hit initially
+    req0 = make_request("0", all_token_ids, block_size, hash_fn)
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
+
+    assert len(req0.block_hashes) == num_full_blocks
+    assert not computed_blocks.blocks[0]  # No cache hit initially
+    assert num_computed_tokens == 0
+
+    blocks = manager.allocate_slots(
+        req0, len(all_token_ids), num_computed_tokens, computed_blocks
+    )
+    assert blocks is not None
+    # Should have blocks for all groups
+    assert len(blocks.get_block_ids()) == num_groups
+
+    # Second request: should hit cached blocks for common prefix
+    all_token_ids = common_token_ids + [6] * 5
+    req1 = make_request("1", all_token_ids, block_size, hash_fn)
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
+
+    # Should hit cached blocks for all groups
+    assert num_computed_tokens == expect_hit_length * block_size
+    assert len(computed_blocks.blocks) == num_groups
+
+    # Allocate and verify blocks for second request
+    blocks = 
manager.allocate_slots( + req1, + len(all_token_ids) - num_computed_tokens, + num_computed_tokens, + computed_blocks, + ) + assert blocks is not None + assert len(blocks.get_block_ids()) == num_groups + + manager.free(req0) + manager.free(req1) + + def test_prefill_plp(): """Test prefill with APC and some prompt logprobs (plp) requests. diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index c72fbb7be193..fdf4bb273f7c 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -485,9 +485,10 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: for spec, group_ids, manager_cls in self.attention_groups: is_full_attn = isinstance(spec, FullAttentionSpec) - # Full attention: reuse cached blocks (downward-closed property) + # Full attention or eagle: reuse cached blocks + # (downward-closed property) cached_blocks = hit_blocks_by_group[group_ids[0]] - if is_full_attn and cached_blocks is not None: + if (is_full_attn or self.use_eagle) and cached_blocks is not None: # For full attention, we only need to compute the cache hit # length once. Starting from the second iteration, if the # curr_hit_length is reduced by other groups, we can simply