diff --git a/tests/v1/kv_connector/unit/test_mooncake_store_scheduler.py b/tests/v1/kv_connector/unit/test_mooncake_store_scheduler.py index 5ee4620d5ace..ac36005c63ef 100644 --- a/tests/v1/kv_connector/unit/test_mooncake_store_scheduler.py +++ b/tests/v1/kv_connector/unit/test_mooncake_store_scheduler.py @@ -16,7 +16,6 @@ def _make_bare_scheduler() -> MooncakeStoreScheduler: scheduler = object.__new__(MooncakeStoreScheduler) scheduler.kv_role = "kv_both" - scheduler.original_block_size = 16 scheduler._block_size = 16 scheduler.load_specs = {} scheduler._preempted_req_ids = set() diff --git a/tests/v1/kv_connector/unit/test_mooncake_store_worker.py b/tests/v1/kv_connector/unit/test_mooncake_store_worker.py index 375aad4eeb8f..d709d1f5971f 100644 --- a/tests/v1/kv_connector/unit/test_mooncake_store_worker.py +++ b/tests/v1/kv_connector/unit/test_mooncake_store_worker.py @@ -134,7 +134,6 @@ def _make_store_req(req_id: str, block_hashes: list[bytes]) -> ReqMeta: block_ids=([0, 1],), block_hashes=block_hashes, can_save=True, - original_block_size=16, ) @@ -970,7 +969,6 @@ def test_store_sending_thread_clamps_token_len_to_lcm(): block_ids=([0, 1, 2],), block_hashes=[b"a0", b"a1", b"a2"], can_save=True, - original_block_size=16, ) ) @@ -1010,7 +1008,6 @@ def test_store_sending_thread_skips_when_token_len_below_lcm(): block_ids=([0, 1],), block_hashes=[b"a0", b"a1"], can_save=True, - original_block_size=64, ) ) @@ -1088,7 +1085,6 @@ def test_store_sending_thread_only_stores_swa_blocks_in_window(): block_ids=([0, 1], list(range(8))), block_hashes=hs, can_save=True, - original_block_size=32, ) ) @@ -1104,6 +1100,84 @@ def test_store_sending_thread_only_stores_swa_blocks_in_window(): assert swa_hashes == {hs[3].hex(), hs[7].hex()} +def test_store_sending_thread_kv_events_use_group_chunk_metadata(): + from vllm.v1.core.kv_cache_utils import BlockHash, maybe_convert_block_hash + from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheGroupSpec, + SlidingWindowSpec, + ) + + store = MagicMock() + store.batch_is_exist.side_effect = lambda keys: [0] * len(keys) + store.batch_put_from_multi_buffers.return_value = [256, 256] + + full_spec = FullAttentionSpec( + block_size=32, num_kv_heads=8, head_size=64, dtype=None + ) + swa_spec = SlidingWindowSpec( + block_size=8, + num_kv_heads=8, + head_size=64, + dtype=None, + sliding_window=8, + ) + coord = mooncake_store_worker.MooncakeStoreCoordinator( + [KVCacheGroupSpec(["L0"], full_spec), KVCacheGroupSpec(["L1"], swa_spec)], + scheduler_block_size=32, + hash_block_size=8, + ) + + db_full = ChunkedTokenDatabase( + KeyMetadata("test-model", 0, 0, 0, 0, group_id=0), + block_size=32, + hash_block_size=8, + ) + db_full.set_kv_caches_base_addr([0x1000]) + db_full.set_block_len([512]) + db_swa = ChunkedTokenDatabase( + KeyMetadata("test-model", 0, 0, 0, 0, group_id=1), + block_size=8, + hash_block_size=8, + ) + db_swa.set_kv_caches_base_addr([0x2000]) + db_swa.set_block_len([128]) + + thread = _make_store_sending_thread( + store, + coord=coord, + token_databases=[db_full, db_swa], + block_size=32, + ) + thread.enable_kv_event = True + + hs = [bytes([i + 1]) * 4 for i in range(4)] + thread.add_stored_request("r0") + thread._handle_request( + ReqMeta( + req_id="r0", + token_len_chunk=32, + block_ids=([0], list(range(4))), + block_hashes=hs, + can_save=True, + token_ids=list(range(32)), + ) + ) + + full_event, swa_event = thread.get_kv_events() + assert full_event.group_idx == 0 + assert full_event.block_size == 32 + assert full_event.token_ids == list(range(32)) + assert full_event.block_hashes == [ + maybe_convert_block_hash(BlockHash(b"".join(hs))) + ] + + assert swa_event.group_idx == 1 + assert swa_event.block_size == 8 + assert swa_event.token_ids == list(range(24, 32)) + assert swa_event.block_hashes == [maybe_convert_block_hash(BlockHash(hs[3]))] + + def _auto_set_ready_event(*args, **kwargs): """Side effect for mocked thread constructors that auto-sets ready_event.""" for arg in args: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/data.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/data.py index 2a625c062779..b26e6835a9c9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/data.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/data.py @@ -213,7 +213,6 @@ class ReqMeta: current_event: torch.cuda.Event | None = None token_ids: list[int] | None = None - original_block_size: int | None = None @staticmethod def from_request_tracker( @@ -223,7 +222,6 @@ def from_request_tracker( skip_save: bool | None = False, block_hashes: list[BlockHash] | None = None, is_last_chunk: bool | None = None, - original_block_size: int | None = None, ) -> "ReqMeta | None": """Create ReqMeta from a RequestTracker.""" if block_hashes is None: @@ -274,7 +272,6 @@ def from_request_tracker( block_hashes=block_hashes, is_last_chunk=is_last_chunk, token_ids=token_ids, - original_block_size=original_block_size, ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/scheduler.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/scheduler.py index 52bab591a9be..4c4d55df3e1e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/scheduler.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/scheduler.py @@ -59,11 +59,7 @@ def __init__( ) self.client = LookupKeyClient(vllm_config) - self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size - self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size - self.original_block_size = vllm_config.cache_config.block_size - # LCM for multi-group HMA; bs * pcp * dcp for single-group. Matches - # the engine's own scheduler block size by construction. + # Align with the engine's own scheduler_block_size and hash_block_size. self._block_size, self._hash_block_size = resolve_kv_cache_block_sizes( kv_cache_config, vllm_config ) @@ -221,7 +217,6 @@ def build_connector_meta( skip_save=force_skip_save, block_hashes=request_real.block_hashes, is_last_chunk=(request_tracker.token_len >= last_chunk_tokens_num), - original_block_size=self.original_block_size, ) if req_meta is not None: meta.add_request(req_meta) @@ -274,7 +269,6 @@ def build_connector_meta( is_last_chunk=( request_tracker.token_len >= last_chunk_tokens_num ), - original_block_size=self.original_block_size, ) else: # Decode/chunked request @@ -312,7 +306,6 @@ def build_connector_meta( is_last_chunk=( request_tracker.token_len >= last_chunk_tokens_num ), - original_block_size=self.original_block_size, ) if req_meta is not None: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/worker.py index cd4eb5c37136..b3a40d8b45f4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/worker.py @@ -507,7 +507,8 @@ def _clear_store_pressure(self) -> bool: return True def _handle_request(self, req_meta: ReqMeta): - # Cache hits are always a multiple of ``lcm_block_size`` tokens + # Cache hits are always a multiple of ``lcm_block_size`` tokens, which + # is also ``store_mask``'s precondition. lcm_block_size = self.coord.lcm_block_size token_len = req_meta.token_len_chunk // lcm_block_size * lcm_block_size block_ids_per_group = req_meta.block_ids @@ -550,7 +551,7 @@ def _handle_request(self, req_meta: ReqMeta): starts.append(start) ends.append(end) keys.append(key.to_string()) - block_hashes.append(req_meta.block_hashes[chunk_idx]) + block_hashes.append(BlockHash(bytes.fromhex(key.chunk_hash))) group_indices.append(g_idx) # Apply put_step striding for TP @@ -627,10 +628,11 @@ def _handle_request(self, req_meta: ReqMeta): block_hashes=[new_block_hashes[idx]], parent_block_hash=prev_key_per_group.get(g_idx), token_ids=token_ids, - block_size=req_meta.original_block_size, + block_size=db.block_size, lora_id=None, medium="cpu", lora_name=None, + group_idx=g_idx, ) stored_events.append(stored_event) prev_key_per_group[g_idx] = new_block_hashes[idx] @@ -947,7 +949,6 @@ def __init__( "load_async", True ) self.cache_config = vllm_config.cache_config - self.original_block_size = self.cache_config.block_size self.block_size, self.hash_block_size = resolve_kv_cache_block_sizes( kv_cache_config, vllm_config )