diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py
index ea82322b00b..1758e3a1040 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py
@@ -114,6 +114,7 @@ class ReqMeta:
     remote_cache_tokens: int = 0
     local_computed_tokens: int = 0
     local_transed_tokens: int = 0
+    do_virtual: bool = False
 
 
 @dataclass
@@ -587,6 +588,7 @@ def add_new_req(
             remote_tp_size=kv_transfer_params.get("remote_tp_size"),
             remote_pcp_size=kv_transfer_params.get("remote_pcp_size"),
             remote_dcp_size=kv_transfer_params.get("remote_dcp_size"),
+            do_virtual=kv_transfer_params.get("do_virtual"),
             chunk_finish=chunk_finish,
             remote_cache_tokens=remote_cache_tokens,
             local_computed_tokens=local_computed_tokens,
@@ -763,6 +765,7 @@ def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: in
 
     def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
         params = request.kv_transfer_params
+        do_virtual = params.get("do_virtual")
         logger.debug(
             "MooncakeLayerwiseConnector update_state_after_alloc: num_external_tokens=%s, kv_transfer_params=%s",
             num_external_tokens,
@@ -804,16 +807,16 @@
                 remote_dcp_size=self.vllm_config.parallel_config.decode_context_parallel_size,
                 remote_cached_tokens=remote_cached_tokens,
             )
+            if not do_virtual:
+                future = self.executor.submit(
+                    self._access_metaserver, url=params.get("metaserver", None), message=kv_transfer_params
+                )
-            future = self.executor.submit(
-                self._access_metaserver, url=params.get("metaserver", None), message=kv_transfer_params
-            )
-
-            def handle_exception(future):
-                if future.exception():
-                    logger.error(f"Access metaserver fail: {future.exception()}")
+
+                def handle_exception(future):
+                    if future.exception():
+                        logger.error(f"Access metaserver fail: {future.exception()}")
-            future.add_done_callback(handle_exception)
+                future.add_done_callback(handle_exception)
 
         # Layerwise prefiller add request need send
         if params is not None and params.get("do_remote_decode"):
@@ -1034,6 +1037,7 @@ def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig, engi
         # TODO(kunpengW-code): Reuse k_buffer, v_buffer
         self.k_quant_buffer: torch.Tensor | None = None
         self.v_quant_buffer: torch.Tensor | None = None
+        self.virtual_request: set[str] = set()
 
     def create_kv_buffer(self, first_kv_cache_tuple):
         alignment = 2 * 1024 * 1024
@@ -1231,6 +1235,8 @@ def get_finished(self) -> tuple[set[str], set[str]]:
             if self.vllm_config.kv_transfer_config.is_kv_consumer
             else set()
         )
+        done_recving.update(self.virtual_request)
+        self.virtual_request = set()
         if len(done_recving) > 0:
             logger.info(
                 f"Number of completed KV cache recv requests: {len(done_recving)}, receive requests: {done_recving}"
@@ -1403,6 +1409,9 @@ def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata):
         self.current_layer = 0
         if self.vllm_config.kv_transfer_config.is_kv_consumer:
             for req_id, meta in metadata.requests.items():
+                if meta.do_virtual:
+                    self.virtual_request.add(req_id)
+                    continue
                 external_req_id = get_external_request_id(req_id)
                 assert self.kv_recv_layer_thread is not None
                 with self.kv_recv_layer_thread.lock: