Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class ReqMeta:
remote_host: str
remote_port: int
remote_engine_id: str
remote_request_id: str
remote_pcp_size: int
remote_dcp_size: int
remote_multi_nodes_meta_mapping: dict[str, dict[str, Any]]
Expand Down Expand Up @@ -375,6 +376,7 @@ def __init__(self,

def add_request(self,
request_id: str,
remote_request_id: str,
local_block_ids: list[int],
remote_block_ids: list[int],
remote_engine_id: str,
Expand All @@ -391,6 +393,7 @@ def add_request(self,
"local_block_ids": local_block_ids,
"remote_block_ids": remote_block_ids,
"remote_engine_id": remote_engine_id,
"remote_request_id": remote_request_id,
"remote_host": remote_host,
"remote_handshake_port": remote_handshake_port,
"offset": offset,
Expand Down Expand Up @@ -423,21 +426,22 @@ def run(self):

def _handle_request(self, req_meta: dict[str, Any]):
request_id = req_meta["request_id"]
remote_request_id = req_meta["remote_request_id"]
remote_host = req_meta["remote_host"]
remote_handshake_port = req_meta["remote_handshake_port"]
remote_port_send_num = req_meta["remote_port_send_num"]
all_task_done = req_meta["all_task_done"]

try:
logger.debug(
f"Starting to transfer KV cache for request {request_id}.")
f"Starting to transfer KV cache for request {remote_request_id}.")
self._transfer_kv_cache(req_meta)
logger.debug(
f"Finished transferring KV cache for request {request_id}.")
f"Finished transferring KV cache for request {remote_request_id}.")
except Exception as e:
logger.error(
"Failed to transfer KV cache for request "
f"{request_id}: {e}",
f"{remote_request_id}: {e}",
exc_info=True)
finally:
if all_task_done:
Expand All @@ -448,10 +452,10 @@ def _handle_request(self, req_meta: dict[str, Any]):
# Always send the done signal to the remote host to ensure proper
# resource cleanup. Failing to do so may cause a memory leak on the
# remote host.
self._send_done_recv_signal(request_id, remote_host,
self._send_done_recv_signal(remote_request_id, remote_host,
remote_handshake_port,
remote_port_send_num)
self._send_done_signal_to_free_remote_port(request_id, remote_host,
self._send_done_signal_to_free_remote_port(remote_request_id, remote_host,
remote_port_send_num)

def _send_done_signal_to_free_remote_port(self, request_id, remote_host,
Expand All @@ -472,7 +476,7 @@ def _send_done_signal_to_free_remote_port(self, request_id, remote_host,

def _transfer_kv_cache(self, req_meta: dict[str, Any]):
"""Handle a KV cache transfer request."""
request_id = req_meta["request_id"]
remote_request_id = req_meta["remote_request_id"]
remote_block_ids = req_meta["remote_block_ids"]
local_block_ids = req_meta["local_block_ids"]
remote_engine_id = req_meta["remote_engine_id"]
Expand Down Expand Up @@ -558,15 +562,15 @@ def _transfer_kv_cache(self, req_meta: dict[str, Any]):
dst_list, length_list)
if ret < 0:
logger.error("Mooncake transfer failed for request %s",
req_meta["request_id"])
req_meta["remote_request_id"])
raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")

req_end_time = time.perf_counter()
req_transfer_elapsed = (req_end_time - req_start_time) * 1000
logger.info(
"KV cache transfer for request %s took %.2f ms (%d groups,"
" %d blocks). local_ip %s local_device_id %s remote_session_id %s",
request_id, req_transfer_elapsed, num_transfer_groups, num_blocks,
remote_request_id, req_transfer_elapsed, num_transfer_groups, num_blocks,
get_ip(), self.tp_rank, session_id)

# Determine if the current position is the offset position at the end of
Expand Down Expand Up @@ -791,6 +795,7 @@ def add_new_req(
num_external_tokens=num_external_tokens,
remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_request_id=kv_transfer_params["remote_request_id"],
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
Expand Down Expand Up @@ -996,7 +1001,7 @@ def update_state_after_alloc(self, request: "Request",
if params is not None and params.get("do_remote_prefill"):
if params.get("remote_block_ids"):
if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port")):
"remote_port", "remote_request_id")):
local_block_ids = (blocks.get_unhashed_block_ids()
if num_external_tokens > 0 else [])
# Get unhashed blocks to pull from remote.
Expand Down Expand Up @@ -1074,6 +1079,7 @@ def request_finished(
do_remote_decode=False,
remote_block_ids=computed_block_ids,
remote_engine_id=self.engine_id,
remote_request_id=request.request_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
remote_pcp_size=self.pcp_size,
Expand Down Expand Up @@ -1583,6 +1589,7 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
meta.remote_multi_nodes_meta_mapping)
self.kv_recv_thread.add_request(
request_id=req_id,
remote_request_id=meta.remote_request_id,
local_block_ids=local_block_ids_list[pcp_dcp_rank],
remote_block_ids=remote_block_ids_list[
pcp_dcp_rank],
Expand Down Expand Up @@ -1610,6 +1617,7 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
meta.remote_multi_nodes_meta_mapping)
self.kv_recv_thread.add_request(
request_id=req_id,
remote_request_id=meta.remote_request_id,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
remote_engine_id=remote_engine_id,
Expand Down