diff --git a/examples/online_serving/kv_events_subscriber.py b/examples/online_serving/kv_events_subscriber.py
index 499ab1f39466..0512297fcf4f 100644
--- a/examples/online_serving/kv_events_subscriber.py
+++ b/examples/online_serving/kv_events_subscriber.py
@@ -43,10 +43,13 @@ class BlockStored(KVCacheEvent):
     prompt embeddings data, etc. for that specific block.
     """
 
+    group_idx: int | None = None
+
 
 class BlockRemoved(KVCacheEvent):
     block_hashes: list[ExternalBlockHash]
     medium: str | None
+    group_idx: int | None = None
 
 
 class AllBlocksCleared(KVCacheEvent):
diff --git a/tests/distributed/test_kv_cache_events.py b/tests/distributed/test_kv_cache_events.py
new file mode 100644
index 000000000000..57d1c9b546a1
--- /dev/null
+++ b/tests/distributed/test_kv_cache_events.py
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm.distributed.kv_events import BlockRemoved, BlockStored
+
+# Minimal ExternalBlockHash for testing (bytes are a valid ExternalBlockHash).
+_FAKE_HASH: bytes = b"\xab" * 32
+
+
+def _make_block_stored(group_idx: int | None = None) -> BlockStored:
+    return BlockStored(
+        block_hashes=[_FAKE_HASH],
+        parent_block_hash=None,
+        token_ids=[1, 2, 3, 4],
+        block_size=4,
+        lora_id=None,
+        medium="GPU",
+        lora_name=None,
+        group_idx=group_idx,
+    )
+
+
+def _make_block_removed(group_idx: int | None = None) -> BlockRemoved:
+    return BlockRemoved(
+        block_hashes=[_FAKE_HASH],
+        medium="GPU",
+        group_idx=group_idx,
+    )
+
+
+def test_block_stored_default_group_idx_is_none():
+    """group_idx defaults to None when not provided."""
+    event = _make_block_stored()
+    assert event.group_idx is None
+
+
+def test_block_removed_default_group_idx_is_none():
+    """group_idx defaults to None when not provided."""
+    event = _make_block_removed()
+    assert event.group_idx is None
+
+
+@pytest.mark.parametrize("group_idx", [1, 2, 3])
+def test_block_stored_hash_differs_by_group_idx(group_idx: int):
+    """BlockStored events that differ only in group_idx must hash differently."""
+    other_group_idx = group_idx + 1
+    event_a = _make_block_stored(group_idx=group_idx)
+    event_b = _make_block_stored(group_idx=other_group_idx)
+    assert hash(event_a) != hash(event_b)
+
+
+def test_block_stored_hash_same_for_equal_group_idx():
+    """Two BlockStored events with identical fields produce the same hash."""
+    event_a = _make_block_stored(group_idx=1)
+    event_b = _make_block_stored(group_idx=1)
+    assert hash(event_a) == hash(event_b)
+
+
+@pytest.mark.parametrize("group_idx", [1, 2, 3])
+def test_block_removed_hash_differs_by_group_idx(group_idx: int):
+    """BlockRemoved events that differ only in group_idx must hash differently."""
+    other_group_idx = group_idx + 1
+    event_a = _make_block_removed(group_idx=group_idx)
+    event_b = _make_block_removed(group_idx=other_group_idx)
+    assert hash(event_a) != hash(event_b)
+
+
+def test_block_removed_hash_same_for_equal_group_idx():
+    """Two BlockRemoved events with identical fields produce the same hash."""
+    event_a = _make_block_removed(group_idx=1)
+    event_b = _make_block_removed(group_idx=1)
+    assert hash(event_a) == hash(event_b)
diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index d8ecf28cbed1..046f04e0c79a 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -10,6 +10,7 @@
 
 import vllm.v1.core.kv_cache_utils as kv_cache_utils
 from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
+from vllm.config.kv_events import KVEventsConfig
 from vllm.lora.request import LoRARequest
 from vllm.multimodal.inputs import (
     MultiModalFeatureSpec,
@@ -2137,3 +2138,30 @@ def test_unify_hybrid_kv_cache_specs():
 
     with pytest.raises(ValueError):
         kv_cache_utils.unify_hybrid_kv_cache_specs(kv_cache_spec)
+
+
+def test_hma_not_disabled_when_kv_events_enabled():
+    """
+    Enabling KV events must not force disable_hybrid_kv_cache_manager to True.
+
+    KV events used to force-disable the hybrid KV cache manager (HMA). This
+    test guards against reintroducing that behavior by verifying that a
+    VllmConfig with kv_events_config set still resolves
+    disable_hybrid_kv_cache_manager to False when nothing else disables it.
+    """
+    model_config = ModelConfig(max_model_len=16)
+    kv_events_config = KVEventsConfig(
+        enable_kv_cache_events=True,
+        publisher="null",
+    )
+
+    # Leave disable_hybrid_kv_cache_manager as None (the default) so that
+    # VllmConfig.__post_init__ resolves it automatically.
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        kv_events_config=kv_events_config,
+    )
+
+    assert vllm_config.scheduler_config.disable_hybrid_kv_cache_manager is False, (
+        "kv_events_config must not force-disable the hybrid KV cache manager."
+    )
diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index b8b387fffd99..22220599f158 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -1970,6 +1970,7 @@ def test_null_parent_block_hash():
     block_size = 1
     num_cached_blocks = 2
     num_full_blocks = 4
+    kv_cache_group_id = 0
 
     pool = BlockPool(
         num_gpu_blocks=8,
@@ -2002,7 +2003,7 @@ def test_null_parent_block_hash():
         num_cached_blocks=num_cached_blocks,
         num_full_blocks=num_full_blocks,
         block_size=block_size,
-        kv_cache_group_id=0,
+        kv_cache_group_id=kv_cache_group_id,
     )
 
     events = pool.take_events()
@@ -2021,6 +2022,7 @@ def test_null_parent_block_hash():
         for h in req.block_hashes[num_cached_blocks:num_full_blocks]
     ]
     assert event.block_hashes == expected_new_hashes
+    assert event.group_idx == kv_cache_group_id
 
     # Ensure we didn't accidentally assign a hash to the null block.
     assert pool.null_block.block_hash is None
@@ -2087,6 +2089,153 @@ def test_kv_cache_events_with_lora(blocks_to_cache: int):
     assert block_stored_event.block_size == block_size
 
 
+@pytest.mark.parametrize("group_id", [0, 1, 2])
+def test_block_stored_event_group_idx(group_id: int):
+    """Test that BlockStored events emitted by cache_full_blocks carry the
+    correct group_idx."""
+    block_size = 4
+    num_tokens = block_size * 2
+
+    pool = BlockPool(
+        num_gpu_blocks=5,
+        enable_caching=True,
+        hash_block_size=block_size,
+        enable_kv_cache_events=True,
+    )
+
+    req = make_request(
+        "req_grp_idx",
+        prompt_token_ids=list(range(num_tokens)),
+        block_size=block_size,
+        hash_fn=sha256,
+    )
+
+    blocks = pool.get_new_blocks(2)
+    pool.cache_full_blocks(
+        request=req,
+        blocks=blocks,
+        num_cached_blocks=0,
+        num_full_blocks=2,
+        block_size=block_size,
+        kv_cache_group_id=group_id,
+    )
+
+    events = pool.take_events()
+    assert len(events) == 1
+    assert isinstance(events[0], BlockStored)
+    assert events[0].group_idx == group_id
+
+
+def test_block_stored_event_group_idx_multiple_groups():
+    """
+    Test that BlockStored events for separate HMA groups each carry the
+    correct group_idx.
+
+    Simulates the HMA scenario where full-attention blocks (group 0) and
+    sliding-window blocks (group 1) are cached independently and must be
+    distinguishable by consumers doing HMA-aware prefix-cache routing.
+    """
+    block_size = 4
+    num_tokens = block_size * 2
+
+    # null block + 4 usable (2 per group)
+    pool = BlockPool(
+        num_gpu_blocks=5,
+        enable_caching=True,
+        hash_block_size=block_size,
+        enable_kv_cache_events=True,
+    )
+
+    req = make_request(
+        "req_multi_grp",
+        prompt_token_ids=list(range(num_tokens)),
+        block_size=block_size,
+        hash_fn=sha256,
+    )
+
+    # Cache blocks for group 0 (full-attention)
+    blocks_grp0 = pool.get_new_blocks(2)
+    pool.cache_full_blocks(
+        request=req,
+        blocks=blocks_grp0,
+        num_cached_blocks=0,
+        num_full_blocks=2,
+        block_size=block_size,
+        kv_cache_group_id=0,
+    )
+
+    # Cache blocks for group 1 (sliding-window)
+    blocks_grp1 = pool.get_new_blocks(2)
+    pool.cache_full_blocks(
+        request=req,
+        blocks=blocks_grp1,
+        num_cached_blocks=0,
+        num_full_blocks=2,
+        block_size=block_size,
+        kv_cache_group_id=1,
+    )
+
+    events = pool.take_events()
+    assert len(events) == 2
+    assert isinstance(events[0], BlockStored)
+    assert events[0].group_idx == 0
+    assert isinstance(events[1], BlockStored)
+    assert events[1].group_idx == 1
+
+
+@pytest.mark.parametrize("group_id", [0, 1, 2])
+def test_block_removed_event_group_idx(group_id: int):
+    """
+    Test that BlockRemoved events emitted on eviction carry the group_idx
+    extracted from the evicted block's BlockHashWithGroupId via get_group_id().
+    """
+    block_size = 4
+    num_tokens = block_size * 2
+
+    # null block + 4 usable; allocate all 4, cache 2, free all, re-allocate
+    # all 4 so the 2 cached blocks are forced through _maybe_evict_cached_block.
+    pool = BlockPool(
+        num_gpu_blocks=5,
+        enable_caching=True,
+        hash_block_size=block_size,
+        enable_kv_cache_events=True,
+    )
+
+    req = make_request(
+        "req_evict_grp",
+        prompt_token_ids=list(range(num_tokens)),
+        block_size=block_size,
+        hash_fn=sha256,
+    )
+
+    # Allocate all usable blocks and cache the first two for the target group.
+    all_blocks = pool.get_new_blocks(4)
+    pool.cache_full_blocks(
+        request=req,
+        blocks=all_blocks,
+        num_cached_blocks=0,
+        num_full_blocks=2,
+        block_size=block_size,
+        kv_cache_group_id=group_id,
+    )
+
+    # Drain the BlockStored events so only eviction events remain later.
+    pool.take_events()
+
+    # Return all blocks to the free queue so they become eviction candidates.
+    pool.free_blocks(all_blocks)
+
+    # Re-allocate all blocks; the two with hashes trigger BlockRemoved events.
+    pool.get_new_blocks(4)
+
+    events = pool.take_events()
+    removed_events = [e for e in events if isinstance(e, BlockRemoved)]
+
+    assert len(removed_events) == 2
+    for event in removed_events:
+        assert event.group_idx == group_id
+
+
 def test_eagle_enabled_removes_last_block():
     """Verify Eagle does NOT remove blocks when request length is divisible
     by block size."""
diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py
index d2a9ceb2dee2..ab5946ad3ba6 100644
--- a/tests/v1/engine/test_engine_core_client.py
+++ b/tests/v1/engine/test_engine_core_client.py
@@ -936,6 +936,13 @@ async def test_engine_core_client_future_utility_async(
     client.shutdown()
 
 
+@pytest.mark.parametrize(
+    "model_name,num_groups",
+    [
+        ("meta-llama/Llama-3.2-1B-Instruct", 1),
+        ("google/gemma-3-1b-it", 7),
+    ],
+)
 @pytest.mark.parametrize(
     "multiprocessing_mode,publisher_config",
     [(True, "tcp"), (False, "inproc")],
@@ -944,12 +951,14 @@ async def test_engine_core_client_future_utility_async(
 def test_kv_cache_events(
     multiprocessing_mode: bool,
     publisher_config,
+    model_name: str,
+    num_groups: int,
 ):
     block_size = 16
     num_blocks = 2
 
     engine_args = EngineArgs(
-        model=MODEL_NAME,
+        model=model_name,
         enforce_eager=True,
         enable_prefix_caching=True,
         block_size=block_size,
@@ -985,26 +994,29 @@ def test_kv_cache_events(
 
         assert result is not None, "No message received"
         seq, received = result
 
-        assert seq == 0, "Sequence number mismatch"
-        assert len(received.events) == 1, "We should have exactly one BlockStored event"
-        event = received.events[0]
-        assert isinstance(event, BlockStored), "We should have a BlockStored event"
-        assert len(event.block_hashes) == num_blocks, (
-            "We should have a BlockStored event with 2 block_hashes"
-        )
-        assert event.block_size == block_size, (
-            "Block size should be the same as the block size"
-        )
-        assert event.parent_block_hash is None, "Parent block hash should be None"
-        assert event.lora_id is None, "Lora id should be None"
-        assert event.lora_name is None, "Lora name should be None"
-        assert len(event.token_ids) == num_blocks * block_size, (
-            "Token ids should be the same as the custom tokens"
-        )
-        assert event.token_ids == custom_tokens, (
-            "Token ids should be the same as the custom tokens"
+        assert len(received.events) == num_groups, (
+            f"Expected {num_groups} BlockStored event(s), got {len(received.events)}"
         )
+
+        for index, event in enumerate(received.events):
+            assert isinstance(event, BlockStored), "We should have a BlockStored event"
+            assert len(event.block_hashes) == num_blocks, (
+                "We should have a BlockStored event with 2 block_hashes"
+            )
+            assert event.block_size == block_size, (
+                "Block size should be the same as the block size"
+            )
+            assert event.parent_block_hash is None, "Parent block hash should be None"
+            assert event.lora_id is None, "Lora id should be None"
+            assert event.lora_name is None, "Lora name should be None"
+            assert len(event.token_ids) == num_blocks * block_size, (
+                "Token ids should be the same as the custom tokens"
+            )
+            assert event.token_ids == custom_tokens, (
+                "Token ids should be the same as the custom tokens"
+            )
+            assert event.group_idx == index
     finally:
         client.shutdown()
         subscriber.close()
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 3b8431e9530a..1f56b73914ce 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -1229,9 +1229,6 @@ def has_blocked_weights():
         if not current_platform.support_hybrid_kv_cache():
             # Hybrid KV cache manager is not supported on non-GPU platforms.
             need_disable_hybrid_kv_cache_manager = True
-        if self.kv_events_config is not None:
-            # Hybrid KV cache manager is not compatible with KV events.
-            need_disable_hybrid_kv_cache_manager = True
         if (
             self.model_config is not None
             and self.model_config.attention_chunk_size is not None
diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py
index 21ec7a36e984..d3e304f8b603 100644
--- a/vllm/distributed/kv_events.py
+++ b/vllm/distributed/kv_events.py
@@ -67,6 +67,8 @@ class BlockStored(KVCacheEvent):
     KV cache consumers to reconstruct block hashes.
     """
 
+    group_idx: int | None = None
+
     def __hash__(self) -> int:
         return hash(
             (
@@ -77,6 +79,7 @@ def __hash__(self) -> int:
                 self.lora_id,
                 self.medium,
                 tuple(self.extra_keys) if self.extra_keys else None,
+                self.group_idx,
             )
         )
 
@@ -84,9 +87,16 @@ def __hash__(self) -> int:
 class BlockRemoved(KVCacheEvent):
     block_hashes: list[ExternalBlockHash]
     medium: str | None
+    group_idx: int | None = None
 
     def __hash__(self) -> int:
-        return hash((tuple(self.block_hashes), self.medium))
+        return hash(
+            (
+                tuple(self.block_hashes),
+                self.medium,
+                self.group_idx,
+            )
+        )
 
 
 class AllBlocksCleared(KVCacheEvent):
diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py
index 4b62d2a4c642..9097079ef33a 100644
--- a/vllm/v1/core/block_pool.py
+++ b/vllm/v1/core/block_pool.py
@@ -22,6 +22,7 @@
     KVCacheBlock,
     generate_block_hash_extra_keys,
     get_block_hash,
+    get_group_id,
     make_block_hash_with_group_id,
     maybe_convert_block_hash,
 )
@@ -314,6 +315,7 @@ def cache_full_blocks(
                     if request.lora_request
                     else None,
                     extra_keys=extra_keys_list if extra_keys_list else None,
+                    group_idx=kv_cache_group_id,
                 )
             )
 
@@ -377,14 +379,11 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
         block.reset_hash()
 
         if self.enable_kv_cache_events:
-            # FIXME (Chen): Not sure whether we should return `hash_value`
-            # or `(hash_value, group_id)` here. But it's fine now because
-            # we disable hybrid kv cache manager when kv cache event is
-            # enabled, so there is only one group.
             self.kv_event_queue.append(
                 BlockRemoved(
                     block_hashes=[maybe_convert_block_hash(get_block_hash(block_hash))],
                     medium=MEDIUM_GPU,
+                    group_idx=get_group_id(block_hash),
                 )
            )
         return True
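Not part of the diff above: a minimal sketch of how an event consumer (for example, the kv_events_subscriber.py example touched by this change) might use the new group_idx field to keep one prefix-cache index per KV cache group. The GroupAwarePrefixIndex class and the policy of folding group_idx=None into group 0 are illustrative assumptions of this sketch, not something the change prescribes.

# Illustrative sketch only: GroupAwarePrefixIndex is a hypothetical helper, not
# part of vLLM. It assumes events have already been decoded into the
# BlockStored / BlockRemoved / AllBlocksCleared classes shown above.
from collections import defaultdict

from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved, BlockStored


class GroupAwarePrefixIndex:
    """Tracks cached block hashes separately for each KV cache group."""

    def __init__(self) -> None:
        # group_idx -> set of externally visible block hashes.
        # Events without a group_idx (older publishers) are treated as group 0;
        # that fallback is an assumption of this sketch.
        self._blocks_by_group: defaultdict[int, set] = defaultdict(set)

    def apply(self, event) -> None:
        """Update the index from a single BlockStored/BlockRemoved/AllBlocksCleared event."""
        if isinstance(event, BlockStored):
            group = event.group_idx if event.group_idx is not None else 0
            self._blocks_by_group[group].update(event.block_hashes)
        elif isinstance(event, BlockRemoved):
            group = event.group_idx if event.group_idx is not None else 0
            self._blocks_by_group[group].difference_update(event.block_hashes)
        elif isinstance(event, AllBlocksCleared):
            self._blocks_by_group.clear()

    def is_cached(self, block_hash, group_idx: int = 0) -> bool:
        """Return True if block_hash is currently cached in the given group."""
        return block_hash in self._blocks_by_group[group_idx]

Keeping one set per group means a prefix-cache router can, depending on its policy, require a hit in every group of a hybrid model or only in the full-attention group, instead of treating all groups' blocks as one undifferentiated pool.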