diff --git a/tests/v1/kv_connector/unit/test_lmcache_connector.py b/tests/v1/kv_connector/unit/test_lmcache_connector.py index c3df2b68b1ff..d2d955161f26 100644 --- a/tests/v1/kv_connector/unit/test_lmcache_connector.py +++ b/tests/v1/kv_connector/unit/test_lmcache_connector.py @@ -7,7 +7,7 @@ from vllm.distributed.kv_events import BlockStored from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import ( LMCacheConnectorV1, - LMCacheKVEvents, + LMCacheWorkerMetadata, ) from vllm.v1.outputs import KVConnectorOutput @@ -50,12 +50,12 @@ def __init__( def mock_connector(): """Create a mock LMCacheConnectorV1 instance with mocked dependencies.""" connector = MagicMock(spec=LMCacheConnectorV1) - connector._kv_cache_events = None + connector._accumulated_worker_meta = None connector._lmcache_engine = MagicMock() # Make the methods use the real implementation - connector.get_kv_connector_kv_cache_events = ( - LMCacheConnectorV1.get_kv_connector_kv_cache_events.__get__( + connector.build_connector_worker_meta = ( + LMCacheConnectorV1.build_connector_worker_meta.__get__( connector, LMCacheConnectorV1 ) ) @@ -71,14 +71,14 @@ def mock_connector(): return connector -class TestGetKVConnectorKVCacheEvents: - """Test get_kv_connector_kv_cache_events method.""" +class TestBuildConnectorWorkerMeta: + """Test build_connector_worker_meta method.""" def test_returns_none_when_no_events(self, mock_connector): """Test that None is returned when lmcache engine has no events.""" mock_connector._lmcache_engine.get_kv_events.return_value = None - result = mock_connector.get_kv_connector_kv_cache_events() + result = mock_connector.build_connector_worker_meta() assert result is None mock_connector._lmcache_engine.get_kv_events.assert_called_once() @@ -87,7 +87,7 @@ def test_returns_none_when_empty_list(self, mock_connector): """Test that None is returned when lmcache engine returns empty list.""" mock_connector._lmcache_engine.get_kv_events.return_value = [] - result = mock_connector.get_kv_connector_kv_cache_events() + result = mock_connector.build_connector_worker_meta() assert result is None @@ -97,13 +97,13 @@ def test_converts_single_event(self, mock_connector, mock_lmcache_engine_event): mock_lmcache_engine_event ] - result = mock_connector.get_kv_connector_kv_cache_events() + result = mock_connector.build_connector_worker_meta() assert result is not None - assert isinstance(result, LMCacheKVEvents) - assert result.get_number_of_workers() == 1 + assert isinstance(result, LMCacheWorkerMetadata) + assert result.num_workers == 1 - events = result.get_all_events() + events = result.kv_events assert len(events) == 1 assert isinstance(events[0], BlockStored) assert events[0].block_hashes == ["hash1", "hash2"] @@ -130,12 +130,12 @@ def __init__(self, i): events = [MockEvent(i) for i in range(5)] mock_connector._lmcache_engine.get_kv_events.return_value = events - result = mock_connector.get_kv_connector_kv_cache_events() + result = mock_connector.build_connector_worker_meta() assert result is not None - assert isinstance(result, LMCacheKVEvents) + assert isinstance(result, LMCacheWorkerMetadata) - converted_events = result.get_all_events() + converted_events = result.kv_events assert len(converted_events) == 5 for i, event in enumerate(converted_events): @@ -161,9 +161,9 @@ def __init__(self): MockEventWithLora() ] - result = mock_connector.get_kv_connector_kv_cache_events() + result = mock_connector.build_connector_worker_meta() - events = result.get_all_events() + events = result.kv_events event = events[0] assert event.block_hashes == ["hash_a", "hash_b", "hash_c"] @@ -191,39 +191,35 @@ def __init__(self): MockEventNoParent() ] - result = mock_connector.get_kv_connector_kv_cache_events() + result = mock_connector.build_connector_worker_meta() - events = result.get_all_events() + events = result.kv_events assert events[0].parent_block_hash is None class TestUpdateConnectorOutput: """Test update_connector_output method.""" - def test_does_nothing_when_kv_cache_events_is_none(self, mock_connector): - """Test that method returns early when kv_cache_events is None.""" - connector_output = KVConnectorOutput(kv_cache_events=None) + def test_does_nothing_when_worker_meta_is_none(self, mock_connector): + """Test that method returns early when kv_connector_worker_meta is None.""" + connector_output = KVConnectorOutput(kv_connector_worker_meta=None) mock_connector.update_connector_output(connector_output) - assert mock_connector._kv_cache_events is None + assert mock_connector._accumulated_worker_meta is None - def test_does_nothing_when_kv_cache_events_is_not_lmcache_kv_events( - self, mock_connector - ): - """Test that method returns early when kv_cache_events is not - LMCacheKVEvents.""" - # Create a mock object that is not LMCacheKVEvents - fake_events = MagicMock() - connector_output = KVConnectorOutput(kv_cache_events=fake_events) + def test_does_nothing_when_worker_meta_is_not_lmcache(self, mock_connector): + """Test that method returns early when worker_meta is not + LMCacheWorkerMetadata.""" + fake_meta = MagicMock() + connector_output = KVConnectorOutput(kv_connector_worker_meta=fake_meta) mock_connector.update_connector_output(connector_output) - assert mock_connector._kv_cache_events is None + assert mock_connector._accumulated_worker_meta is None - def test_sets_kv_cache_events_when_none(self, mock_connector): - """Test that _kv_cache_events is set when it was None.""" - kv_events = LMCacheKVEvents(num_workers=1) + def test_sets_worker_meta_when_none(self, mock_connector): + """Test that _accumulated_worker_meta is set when it was None.""" event = BlockStored( block_hashes=["hash1"], parent_block_hash=None, @@ -233,18 +229,17 @@ def test_sets_kv_cache_events_when_none(self, mock_connector): medium="GPU", lora_name=None, ) - kv_events.add_events([event]) + worker_meta = LMCacheWorkerMetadata(kv_events=[event], num_workers=1) - connector_output = KVConnectorOutput(kv_cache_events=kv_events) + connector_output = KVConnectorOutput(kv_connector_worker_meta=worker_meta) mock_connector.update_connector_output(connector_output) - assert mock_connector._kv_cache_events is kv_events + assert mock_connector._accumulated_worker_meta is worker_meta - def test_adds_events_when_kv_cache_events_already_exists(self, mock_connector): - """Test that events are added when _kv_cache_events already exists.""" - # Set up existing events - existing_events = LMCacheKVEvents(num_workers=2) + def test_aggregates_when_worker_meta_already_exists(self, mock_connector): + """Test that worker meta is aggregated when _accumulated_worker_meta + already exists.""" event1 = BlockStored( block_hashes=["hash1"], parent_block_hash=None, @@ -254,13 +249,9 @@ def test_adds_events_when_kv_cache_events_already_exists(self, mock_connector): medium="GPU", lora_name=None, ) - existing_events.add_events([event1]) - existing_events.add_events([event1]) # Simulate 2 workers reporting + existing_meta = LMCacheWorkerMetadata(kv_events=[event1, event1], num_workers=2) + mock_connector._accumulated_worker_meta = existing_meta - mock_connector._kv_cache_events = existing_events - - # Create new events to add - new_events = LMCacheKVEvents(num_workers=1) event2 = BlockStored( block_hashes=["hash2"], parent_block_hash=None, @@ -270,28 +261,22 @@ def test_adds_events_when_kv_cache_events_already_exists(self, mock_connector): medium="GPU", lora_name=None, ) - new_events.add_events([event2]) - - connector_output = KVConnectorOutput(kv_cache_events=new_events) + new_meta = LMCacheWorkerMetadata(kv_events=[event2], num_workers=1) + connector_output = KVConnectorOutput(kv_connector_worker_meta=new_meta) mock_connector.update_connector_output(connector_output) - # Check that events were added - all_events = mock_connector._kv_cache_events.get_all_events() + # Check that events were aggregated + all_events = mock_connector._accumulated_worker_meta.kv_events assert len(all_events) == 3 # 2 from existing + 1 from new assert event1 in all_events assert event2 in all_events - def test_increments_workers_when_kv_cache_events_already_exists( - self, mock_connector - ): - """Test that worker count is incremented correctly.""" - # Set up existing events with 2 workers - existing_events = LMCacheKVEvents(num_workers=2) - mock_connector._kv_cache_events = existing_events + def test_increments_workers_when_worker_meta_already_exists(self, mock_connector): + """Test that worker count is aggregated correctly.""" + existing_meta = LMCacheWorkerMetadata(kv_events=[], num_workers=2) + mock_connector._accumulated_worker_meta = existing_meta - # Create new events from 3 workers - new_events = LMCacheKVEvents(num_workers=3) event = BlockStored( block_hashes=["hash1"], parent_block_hash=None, @@ -301,19 +286,17 @@ def test_increments_workers_when_kv_cache_events_already_exists( medium="GPU", lora_name=None, ) - new_events.add_events([event]) - - connector_output = KVConnectorOutput(kv_cache_events=new_events) + new_meta = LMCacheWorkerMetadata(kv_events=[event], num_workers=3) + connector_output = KVConnectorOutput(kv_connector_worker_meta=new_meta) mock_connector.update_connector_output(connector_output) # Worker count should be 2 + 3 = 5 - assert mock_connector._kv_cache_events.get_number_of_workers() == 5 + assert mock_connector._accumulated_worker_meta.num_workers == 5 def test_multiple_updates(self, mock_connector): """Test multiple consecutive updates.""" # First update - events1 = LMCacheKVEvents(num_workers=1) event1 = BlockStored( block_hashes=["hash1"], parent_block_hash=None, @@ -323,12 +306,11 @@ def test_multiple_updates(self, mock_connector): medium="GPU", lora_name=None, ) - events1.add_events([event1]) - output1 = KVConnectorOutput(kv_cache_events=events1) + meta1 = LMCacheWorkerMetadata(kv_events=[event1], num_workers=1) + output1 = KVConnectorOutput(kv_connector_worker_meta=meta1) mock_connector.update_connector_output(output1) # Second update - events2 = LMCacheKVEvents(num_workers=2) event2 = BlockStored( block_hashes=["hash2"], parent_block_hash=None, @@ -338,12 +320,11 @@ def test_multiple_updates(self, mock_connector): medium="GPU", lora_name=None, ) - events2.add_events([event2]) - output2 = KVConnectorOutput(kv_cache_events=events2) + meta2 = LMCacheWorkerMetadata(kv_events=[event2], num_workers=2) + output2 = KVConnectorOutput(kv_connector_worker_meta=meta2) mock_connector.update_connector_output(output2) # Third update - events3 = LMCacheKVEvents(num_workers=1) event3 = BlockStored( block_hashes=["hash3"], parent_block_hash=None, @@ -353,19 +334,18 @@ def test_multiple_updates(self, mock_connector): medium="GPU", lora_name=None, ) - events3.add_events([event3]) - output3 = KVConnectorOutput(kv_cache_events=events3) + meta3 = LMCacheWorkerMetadata(kv_events=[event3], num_workers=1) + output3 = KVConnectorOutput(kv_connector_worker_meta=meta3) mock_connector.update_connector_output(output3) # Check final state - all_events = mock_connector._kv_cache_events.get_all_events() + all_events = mock_connector._accumulated_worker_meta.kv_events assert len(all_events) == 3 - assert mock_connector._kv_cache_events.get_number_of_workers() == 4 # 1+2+1 + assert mock_connector._accumulated_worker_meta.num_workers == 4 # 1+2+1 def test_updates_with_empty_events(self, mock_connector): """Test updating with empty event lists.""" # First update with actual events - events1 = LMCacheKVEvents(num_workers=1) event1 = BlockStored( block_hashes=["hash1"], parent_block_hash=None, @@ -375,28 +355,27 @@ def test_updates_with_empty_events(self, mock_connector): medium="GPU", lora_name=None, ) - events1.add_events([event1]) - output1 = KVConnectorOutput(kv_cache_events=events1) + meta1 = LMCacheWorkerMetadata(kv_events=[event1], num_workers=1) + output1 = KVConnectorOutput(kv_connector_worker_meta=meta1) mock_connector.update_connector_output(output1) # Second update with empty events - events2 = LMCacheKVEvents(num_workers=2) - # No events added - output2 = KVConnectorOutput(kv_cache_events=events2) + meta2 = LMCacheWorkerMetadata(kv_events=[], num_workers=2) + output2 = KVConnectorOutput(kv_connector_worker_meta=meta2) mock_connector.update_connector_output(output2) # Should still have the original event - all_events = mock_connector._kv_cache_events.get_all_events() + all_events = mock_connector._accumulated_worker_meta.kv_events assert len(all_events) == 1 - assert mock_connector._kv_cache_events.get_number_of_workers() == 3 + assert mock_connector._accumulated_worker_meta.num_workers == 3 class TestTakeEvents: """Test take_events method.""" - def test_yields_nothing_when_kv_cache_events_is_none(self, mock_connector): - """Test that nothing is yielded when _kv_cache_events is None.""" - mock_connector._kv_cache_events = None + def test_yields_nothing_when_accumulated_meta_is_none(self, mock_connector): + """Test that nothing is yielded when _accumulated_worker_meta is None.""" + mock_connector._accumulated_worker_meta = None events = list(mock_connector.take_events()) @@ -404,8 +383,6 @@ def test_yields_nothing_when_kv_cache_events_is_none(self, mock_connector): def test_yields_events_and_clears(self, mock_connector): """Test that events are yielded and then cleared.""" - # Set up events - kv_events = LMCacheKVEvents(num_workers=1) event1 = BlockStored( block_hashes=["hash1"], parent_block_hash=None, @@ -424,8 +401,9 @@ def test_yields_events_and_clears(self, mock_connector): medium="GPU", lora_name=None, ) - kv_events.add_events([event1, event2]) - mock_connector._kv_cache_events = kv_events + mock_connector._accumulated_worker_meta = LMCacheWorkerMetadata( + kv_events=[event1, event2], num_workers=1 + ) # Take events events = list(mock_connector.take_events()) @@ -435,13 +413,11 @@ def test_yields_events_and_clears(self, mock_connector): assert event1 in events assert event2 in events - # Check that _kv_cache_events was cleared - assert mock_connector._kv_cache_events is None + # Check that _accumulated_worker_meta was cleared + assert mock_connector._accumulated_worker_meta is None def test_aggregates_before_yielding(self, mock_connector): - """Test that events are aggregated before yielding.""" - # Set up events from multiple workers - kv_events = LMCacheKVEvents(num_workers=3) + """Test that events are aggregated (consensus) before yielding.""" common_event = BlockStored( block_hashes=["hash_common"], parent_block_hash=None, @@ -461,15 +437,16 @@ def test_aggregates_before_yielding(self, mock_connector): lora_name=None, ) - # All 3 workers report common_event - kv_events.add_events([common_event]) - kv_events.add_events([common_event]) - kv_events.add_events([common_event]) - - # Only 1 worker reports uncommon_event - kv_events.add_events([uncommon_event]) - - mock_connector._kv_cache_events = kv_events + # Simulate 3 workers: all report common_event, only 1 reports uncommon + mock_connector._accumulated_worker_meta = LMCacheWorkerMetadata( + kv_events=[ + common_event, + common_event, + common_event, + uncommon_event, + ], + num_workers=3, + ) # Take events events = list(mock_connector.take_events()) @@ -481,7 +458,6 @@ def test_aggregates_before_yielding(self, mock_connector): def test_multiple_take_events_calls(self, mock_connector): """Test calling take_events multiple times.""" # First call with events - kv_events1 = LMCacheKVEvents(num_workers=1) event1 = BlockStored( block_hashes=["hash1"], parent_block_hash=None, @@ -491,20 +467,20 @@ def test_multiple_take_events_calls(self, mock_connector): medium="GPU", lora_name=None, ) - kv_events1.add_events([event1]) - mock_connector._kv_cache_events = kv_events1 + mock_connector._accumulated_worker_meta = LMCacheWorkerMetadata( + kv_events=[event1], num_workers=1 + ) events1 = list(mock_connector.take_events()) assert len(events1) == 1 assert events1[0] == event1 - assert mock_connector._kv_cache_events is None + assert mock_connector._accumulated_worker_meta is None # Second call with no events events2 = list(mock_connector.take_events()) assert events2 == [] # Third call after adding new events - kv_events2 = LMCacheKVEvents(num_workers=1) event2 = BlockStored( block_hashes=["hash2"], parent_block_hash=None, @@ -514,8 +490,9 @@ def test_multiple_take_events_calls(self, mock_connector): medium="GPU", lora_name=None, ) - kv_events2.add_events([event2]) - mock_connector._kv_cache_events = kv_events2 + mock_connector._accumulated_worker_meta = LMCacheWorkerMetadata( + kv_events=[event2], num_workers=1 + ) events3 = list(mock_connector.take_events()) assert len(events3) == 1 @@ -523,8 +500,6 @@ def test_multiple_take_events_calls(self, mock_connector): def test_yields_empty_after_aggregation_removes_all(self, mock_connector): """Test that nothing is yielded if aggregation removes all events.""" - # Set up events from 2 workers with no common events - kv_events = LMCacheKVEvents(num_workers=2) event1 = BlockStored( block_hashes=["hash1"], parent_block_hash=None, @@ -544,19 +519,62 @@ def test_yields_empty_after_aggregation_removes_all(self, mock_connector): lora_name=None, ) - # Worker 1 reports event1 - kv_events.add_events([event1]) - # Worker 2 reports event2 - kv_events.add_events([event2]) - - mock_connector._kv_cache_events = kv_events + # 2 workers, each reporting a different event -> no consensus + mock_connector._accumulated_worker_meta = LMCacheWorkerMetadata( + kv_events=[event1, event2], num_workers=2 + ) # Take events events = list(mock_connector.take_events()) # No common events, so nothing should be yielded assert events == [] - assert mock_connector._kv_cache_events is None + assert mock_connector._accumulated_worker_meta is None + + +class TestLMCacheWorkerMetadataAggregate: + """Test LMCacheWorkerMetadata.aggregate method.""" + + def test_aggregate_combines_events(self): + """Test that aggregate combines events from both metadata objects.""" + event1 = BlockStored( + block_hashes=["hash1"], + parent_block_hash=None, + token_ids=[1], + block_size=16, + lora_id=None, + medium="GPU", + lora_name=None, + ) + event2 = BlockStored( + block_hashes=["hash2"], + parent_block_hash=None, + token_ids=[2], + block_size=16, + lora_id=None, + medium="GPU", + lora_name=None, + ) + + meta1 = LMCacheWorkerMetadata(kv_events=[event1], num_workers=1) + meta2 = LMCacheWorkerMetadata(kv_events=[event2], num_workers=1) + + result = meta1.aggregate(meta2) + + assert isinstance(result, LMCacheWorkerMetadata) + assert len(result.kv_events) == 2 + assert event1 in result.kv_events + assert event2 in result.kv_events + assert result.num_workers == 2 + + def test_aggregate_sums_worker_counts(self): + """Test that aggregate sums worker counts correctly.""" + meta1 = LMCacheWorkerMetadata(kv_events=[], num_workers=3) + meta2 = LMCacheWorkerMetadata(kv_events=[], num_workers=2) + + result = meta1.aggregate(meta2) + + assert result.num_workers == 5 class TestIntegrationScenarios: @@ -564,26 +582,26 @@ class TestIntegrationScenarios: def test_full_workflow(self, mock_connector, mock_lmcache_engine_event): """Test a complete workflow from getting events to taking them.""" - # Step 1: Get events from lmcache engine + # Step 1: Build worker metadata from lmcache engine mock_connector._lmcache_engine.get_kv_events.return_value = [ mock_lmcache_engine_event ] - kv_events = mock_connector.get_kv_connector_kv_cache_events() + worker_meta = mock_connector.build_connector_worker_meta() - assert kv_events is not None - assert len(kv_events.get_all_events()) == 1 + assert worker_meta is not None + assert len(worker_meta.kv_events) == 1 # Step 2: Update connector output (simulate receiving from worker) - output1 = KVConnectorOutput(kv_cache_events=kv_events) + output1 = KVConnectorOutput(kv_connector_worker_meta=worker_meta) mock_connector.update_connector_output(output1) - assert mock_connector._kv_cache_events is not None + assert mock_connector._accumulated_worker_meta is not None # Step 3: Take events taken_events = list(mock_connector.take_events()) assert len(taken_events) == 1 - assert mock_connector._kv_cache_events is None + assert mock_connector._accumulated_worker_meta is None def test_multiple_workers_workflow(self, mock_connector): """Test workflow with multiple workers.""" @@ -603,8 +621,8 @@ def __init__(self, hash_val): MockEvent("hash_common"), MockEvent("hash_worker1"), ] - kv_events1 = mock_connector.get_kv_connector_kv_cache_events() - output1 = KVConnectorOutput(kv_cache_events=kv_events1) + worker_meta1 = mock_connector.build_connector_worker_meta() + output1 = KVConnectorOutput(kv_connector_worker_meta=worker_meta1) mock_connector.update_connector_output(output1) # Worker 2 @@ -612,35 +630,34 @@ def __init__(self, hash_val): MockEvent("hash_common"), MockEvent("hash_worker2"), ] - kv_events2 = mock_connector.get_kv_connector_kv_cache_events() - output2 = KVConnectorOutput(kv_cache_events=kv_events2) + worker_meta2 = mock_connector.build_connector_worker_meta() + output2 = KVConnectorOutput(kv_connector_worker_meta=worker_meta2) mock_connector.update_connector_output(output2) # Take events (should only get common events) taken_events = list(mock_connector.take_events()) # With aggregation, only events reported by both workers should be present - # In this case, hash_common was reported by both event_hashes = [e.block_hashes[0] for e in taken_events] assert "hash_common" in event_hashes def test_empty_workflow(self, mock_connector): """Test workflow when there are no events at any stage.""" - # Get events returns None + # Build worker meta returns None mock_connector._lmcache_engine.get_kv_events.return_value = None - kv_events = mock_connector.get_kv_connector_kv_cache_events() + worker_meta = mock_connector.build_connector_worker_meta() - assert kv_events is None + assert worker_meta is None # Update with None - output = KVConnectorOutput(kv_cache_events=None) + output = KVConnectorOutput(kv_connector_worker_meta=None) mock_connector.update_connector_output(output) # Take events taken_events = list(mock_connector.take_events()) assert taken_events == [] - assert mock_connector._kv_cache_events is None + assert mock_connector._accumulated_worker_meta is None def test_repeated_cycles(self, mock_connector): """Test multiple cycles of the complete workflow.""" @@ -656,14 +673,14 @@ def __init__(self, cycle_num): self.lora_name = None for cycle in range(3): - # Get events + # Build worker meta mock_connector._lmcache_engine.get_kv_events.return_value = [ MockEvent(cycle) ] - kv_events = mock_connector.get_kv_connector_kv_cache_events() + worker_meta = mock_connector.build_connector_worker_meta() # Update - output = KVConnectorOutput(kv_cache_events=kv_events) + output = KVConnectorOutput(kv_connector_worker_meta=worker_meta) mock_connector.update_connector_output(output) # Take @@ -672,11 +689,11 @@ def __init__(self, cycle_num): # Verify assert len(taken_events) == 1 assert taken_events[0].block_hashes[0] == f"hash_cycle_{cycle}" - assert mock_connector._kv_cache_events is None + assert mock_connector._accumulated_worker_meta is None - def test_lmcache_kv_events_aggregation(self): + def test_cross_worker_aggregation_via_kv_output_aggregator(self): """ - Test LMCacheKVEvents aggregation across TP ranks using + Test LMCacheWorkerMetadata aggregation across TP ranks using KVOutputAggregator (used by MultiprocExecutor). """ from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator @@ -726,57 +743,50 @@ def test_lmcache_kv_events_aggregation(self): lora_name=None, ) - # Create events for each worker - # Worker 0: reports common event and its unique event - worker0_events = LMCacheKVEvents(num_workers=1) - worker0_events.add_events([common_event, worker1_unique_event]) - - # Worker 1: reports common event and its unique event - worker1_events = LMCacheKVEvents(num_workers=1) - worker1_events.add_events([common_event, worker2_unique_event]) - - # Worker 2: reports common event and its unique event - worker2_events = LMCacheKVEvents(num_workers=1) - worker2_events.add_events([common_event, worker3_unique_event]) + # Create worker metadata for each worker + worker0_meta = LMCacheWorkerMetadata( + kv_events=[common_event, worker1_unique_event], num_workers=1 + ) + worker1_meta = LMCacheWorkerMetadata( + kv_events=[common_event, worker2_unique_event], num_workers=1 + ) + worker2_meta = LMCacheWorkerMetadata( + kv_events=[common_event, worker3_unique_event], num_workers=1 + ) # Create ModelRunnerOutput instances for each worker worker_outputs = [] - for i, worker_events in enumerate( - [worker0_events, worker1_events, worker2_events] - ): + for i, worker_meta in enumerate([worker0_meta, worker1_meta, worker2_meta]): output = ModelRunnerOutput( req_ids=[f"req_{i}"], req_id_to_index={f"req_{i}": 0}, - sampled_token_ids=[[123]], # dummy token + sampled_token_ids=[[123]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[None], kv_connector_output=KVConnectorOutput( - finished_sending=set([f"req_{i}_send"]) - if i < 2 - else None, # Workers 0,1 finished sending - finished_recving=set([f"req_{i}_recv"]) - if i > 0 - else None, # Workers 1,2 finished receiving - kv_cache_events=worker_events, + finished_sending=set([f"req_{i}_send"]) if i < 2 else None, + finished_recving=set([f"req_{i}_recv"]) if i > 0 else None, + kv_connector_worker_meta=worker_meta, ), ) worker_outputs.append(output) - # Use the real aggregation mechanism (like MultiprocExecutor.execute_model) + # Use the real aggregation mechanism aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0) - kv_cache_events = aggregated_output.kv_connector_output.kv_cache_events + aggregated_meta = aggregated_output.kv_connector_output.kv_connector_worker_meta + + assert isinstance(aggregated_meta, LMCacheWorkerMetadata) + assert aggregated_meta.num_workers == 3 - assert isinstance(kv_cache_events, LMCacheKVEvents) + # Now simulate scheduler-side: create a connector mock and use it + from vllm.distributed.kv_events import KVEventAggregator - # After aggregation, events should be combined from all workers - # The aggregator doesn't automatically aggregate events, so we need to call - # aggregate() to get only common events - kv_cache_events.aggregate() - aggregated_events = kv_cache_events.get_all_events() + agg = KVEventAggregator(aggregated_meta.num_workers) + agg.add_events(aggregated_meta.kv_events) + aggregated_events = agg.get_common_events() - # Only the common event should remain after aggregation - # because it's the only event reported by all 3 workers + # Only the common event should remain after consensus aggregation assert len(aggregated_events) == 1 assert aggregated_events[0] == common_event diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index 6acc486292a1..abbada19d840 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -759,12 +759,9 @@ def test_multi_connector_overrides_all_base_methods(): Ensure MultiConnector overrides all public methods from KVConnectorBase_V1. """ # These are fine to inherit from KVConnectorBase_V1 - # TODO(https://github.com/vllm-project/vllm/pull/31811): Remove - # get_kv_connector_kv_cache_events from INHERITED_OK once implemented. INHERITED_OK = { "role", "has_connector_metadata", - "get_kv_connector_kv_cache_events", } base_members = { diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 21ec7a36e984..234e04d272e6 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -179,41 +179,6 @@ def __repr__(self) -> str: ) -class KVConnectorKVEvents(ABC): - """ - Abstract base class for KV events. - Acts as a container for KV events from the connector. - """ - - @abstractmethod - def add_events(self, events: list[KVCacheEvent]) -> None: - raise NotImplementedError - - @abstractmethod - def aggregate(self) -> "KVConnectorKVEvents": - raise NotImplementedError - - @abstractmethod - def increment_workers(self, count: int = 1) -> None: - raise NotImplementedError - - @abstractmethod - def get_all_events(self) -> list[KVCacheEvent]: - raise NotImplementedError - - @abstractmethod - def get_number_of_workers(self) -> int: - raise NotImplementedError - - @abstractmethod - def clear_events(self) -> None: - raise NotImplementedError - - def merge(self, other: "KVConnectorKVEvents") -> "KVConnectorKVEvents": - self.add_events(other.get_all_events()) - return self - - class EventPublisher(ABC): """Lightweight publisher for EventBatch batches with data parallelism support. diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 155395e84e11..e8ca62916dae 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -86,7 +86,6 @@ def update_finished_set( finished_recving = set[str]() aggregated_kv_connector_stats = None aggregated_kv_connector_worker_meta = None - combined_kv_cache_events = None invalid_block_ids = set[int]() for model_runner_output in outputs: assert model_runner_output is not None @@ -139,19 +138,6 @@ def update_finished_set( ) ) - # Combine kv_cache_events from all workers. - if combined_kv_cache_events is None: - # Use the first worker's kv_cache events as start event list. - combined_kv_cache_events = kv_output.kv_cache_events - elif kv_cache_events := kv_output.kv_cache_events: - assert isinstance( - combined_kv_cache_events, - type(kv_cache_events), - ) - worker_kv_cache_events = kv_cache_events.get_all_events() - combined_kv_cache_events.add_events(worker_kv_cache_events) - combined_kv_cache_events.increment_workers(1) - invalid_block_ids |= kv_output.invalid_block_ids # select output of the worker specified by output_rank @@ -162,7 +148,6 @@ def update_finished_set( finished_sending=finished_sending or None, finished_recving=finished_recving or None, kv_connector_stats=aggregated_kv_connector_stats or None, - kv_cache_events=combined_kv_cache_events or None, kv_connector_worker_meta=aggregated_kv_connector_worker_meta or None, invalid_block_ids=invalid_block_ids, expected_finished_count=self._expected_finished_count, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 2abbe6bf610a..b4bfccba4b2d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -54,7 +54,7 @@ if TYPE_CHECKING: from vllm.config import VllmConfig - from vllm.distributed.kv_events import KVCacheEvent, KVConnectorKVEvents + from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( KVConnectorPromMetrics, KVConnectorStats, @@ -412,14 +412,6 @@ def get_kv_connector_stats(self) -> "KVConnectorStats | None": """ return None - def get_kv_connector_kv_cache_events(self) -> "KVConnectorKVEvents | None": - """ - Get the KV connector kv cache events collected during the last interval. - This function should be called by the model runner every time after the - model execution and before cleanup. - """ - return None - def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None: """ Get the KVConnector handshake metadata for this connector. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 64aee2bd9c49..ef0ef220a246 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any import torch @@ -9,13 +10,13 @@ from vllm.distributed.kv_events import ( BlockStored, KVCacheEvent, - KVConnectorKVEvents, KVEventAggregator, ) from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, + KVConnectorWorkerMetadata, ) from vllm.logger import init_logger from vllm.v1.attention.backend import AttentionMetadata @@ -31,42 +32,19 @@ logger = init_logger(__name__) -class LMCacheKVEvents(KVConnectorKVEvents): - """ - Concrete implementation of KVConnectorKVEvents using KVEventAggregator. - """ +@dataclass +class LMCacheWorkerMetadata(KVConnectorWorkerMetadata): + """Worker metadata for LMCache connector.""" - def __init__(self, num_workers: int) -> None: - self._aggregator = KVEventAggregator(num_workers) + kv_events: list[KVCacheEvent] = field(default_factory=list) + num_workers: int = 1 - def add_events(self, events: list[KVCacheEvent]) -> None: - self._aggregator.add_events(events) - - def aggregate(self) -> "LMCacheKVEvents": - """ - Aggregate KV events and retain only common events. - """ - common_events = self._aggregator.get_common_events() - self._aggregator.clear_events() - self._aggregator.add_events(common_events) - self._aggregator.reset_workers() - return self - - def increment_workers(self, count: int = 1) -> None: - self._aggregator.increment_workers(count) - - def get_all_events(self) -> list[KVCacheEvent]: - return self._aggregator.get_all_events() - - def get_number_of_workers(self) -> int: - return self._aggregator.get_number_of_workers() - - def clear_events(self) -> None: - self._aggregator.clear_events() - self._aggregator.reset_workers() - - def __repr__(self) -> str: - return f"" + def aggregate(self, other: KVConnectorWorkerMetadata) -> "LMCacheWorkerMetadata": + assert isinstance(other, LMCacheWorkerMetadata) + return LMCacheWorkerMetadata( + kv_events=self.kv_events + other.kv_events, + num_workers=self.num_workers + other.num_workers, + ) class LMCacheConnectorV1(KVConnectorBase_V1): @@ -112,7 +90,8 @@ def __init__( self._lmcache_engine = cls(vllm_config, role, self) - self._kv_cache_events: LMCacheKVEvents | None = None + # Accumulated worker metadata across steps (scheduler-side). + self._accumulated_worker_meta: LMCacheWorkerMetadata | None = None # ============================== # Worker-side methods @@ -227,11 +206,10 @@ def get_block_ids_with_load_errors(self) -> set[int]: # Fallback for older versions that don't support this method return set() - def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None: + def build_connector_worker_meta(self) -> LMCacheWorkerMetadata | None: """ - Get the KV connector kv cache events collected during the last interval. + Build worker metadata from this step. """ - events = self._lmcache_engine.get_kv_events() # type: ignore [attr-defined] if not events: return None @@ -249,9 +227,7 @@ def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None: for e in events ] - lmcache_kv_events = LMCacheKVEvents(num_workers=1) - lmcache_kv_events.add_events(blocks) - return lmcache_kv_events + return LMCacheWorkerMetadata(kv_events=blocks, num_workers=1) # ============================== # Scheduler-side methods @@ -308,17 +284,15 @@ def update_connector_output(self, connector_output: KVConnectorOutput): connector_output (KVConnectorOutput): the worker-side connectors output. """ - # Get the KV events - kv_cache_events = connector_output.kv_cache_events - if not kv_cache_events or not isinstance(kv_cache_events, LMCacheKVEvents): + worker_meta = connector_output.kv_connector_worker_meta + if not worker_meta or not isinstance(worker_meta, LMCacheWorkerMetadata): return - if self._kv_cache_events is None: - self._kv_cache_events = kv_cache_events + if self._accumulated_worker_meta is None: + self._accumulated_worker_meta = worker_meta else: - self._kv_cache_events.add_events(kv_cache_events.get_all_events()) - self._kv_cache_events.increment_workers( - kv_cache_events.get_number_of_workers() + self._accumulated_worker_meta = self._accumulated_worker_meta.aggregate( + worker_meta ) return @@ -346,9 +320,9 @@ def take_events(self) -> Iterable["KVCacheEvent"]: Yields: New KV cache events since the last call. """ - if self._kv_cache_events is not None: - self._kv_cache_events.aggregate() - kv_cache_events = self._kv_cache_events.get_all_events() - yield from kv_cache_events - self._kv_cache_events.clear_events() - self._kv_cache_events = None + if self._accumulated_worker_meta is not None: + # Consensus aggregation: only keep events reported by all workers. + aggregator = KVEventAggregator(self._accumulated_worker_meta.num_workers) + aggregator.add_events(self._accumulated_worker_meta.kv_events) + yield from aggregator.get_common_events() + self._accumulated_worker_meta = None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 7cc80129a3a1..7899d0ab5115 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -337,12 +337,6 @@ def build_connector_worker_meta(self) -> KVConnectorWorkerMetadata | None: return None return MultiKVConnectorWorkerMetadata(metadata=tuple(metadata_list)) - # TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events' - # method for the MultiConnector. It should be able to get events from - # multiple connectors, handling the case where only a subset of the - # requested connectors implements the 'get_kv_connector_kv_cache_events' - # WIP: https://github.com/vllm-project/vllm/pull/31811 - # ============================== # Scheduler-side methods # ============================== diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 8eb58de4f3fd..830f313540d1 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -13,7 +13,6 @@ from vllm.v1.core.sched.output import SchedulerOutput if TYPE_CHECKING: - from vllm.distributed.kv_events import KVConnectorKVEvents from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorWorkerMetadata, ) @@ -21,7 +20,6 @@ else: KVConnectorStats = object KVConnectorWorkerMetadata = object - KVConnectorKVEvents = object class LogprobsLists(NamedTuple): @@ -145,7 +143,6 @@ class KVConnectorOutput: finished_sending: set[str] | None = None finished_recving: set[str] | None = None kv_connector_stats: KVConnectorStats | None = None - kv_cache_events: KVConnectorKVEvents | None = None kv_connector_worker_meta: KVConnectorWorkerMetadata | None = None # IDs of externally computed KV blocks that failed to load. # Requests referencing these blocks should be rescheduled to recompute them @@ -162,7 +159,6 @@ def is_empty(self): not self.finished_sending and not self.finished_recving and not self.kv_connector_stats - and not self.kv_cache_events and not self.invalid_block_ids and not self.kv_connector_worker_meta ) diff --git a/vllm/v1/worker/gpu/kv_connector.py b/vllm/v1/worker/gpu/kv_connector.py index 7e4e27e1f234..ea5ea563f82f 100644 --- a/vllm/v1/worker/gpu/kv_connector.py +++ b/vllm/v1/worker/gpu/kv_connector.py @@ -93,7 +93,6 @@ def post_forward( ) output.invalid_block_ids = self.kv_connector.get_block_ids_with_load_errors() output.kv_connector_stats = self.kv_connector.get_kv_connector_stats() - output.kv_cache_events = self.kv_connector.get_kv_connector_kv_cache_events() if clear_metadata: self.kv_connector.clear_connector_metadata() return output diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index 2921594a3b42..387d98dcbdc4 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -122,7 +122,6 @@ def _get_kv_connector_output( output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors() output.kv_connector_stats = kv_connector.get_kv_connector_stats() - output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events() output.kv_connector_worker_meta = kv_connector.build_connector_worker_meta() if not defer_finalize: