40 changes: 27 additions & 13 deletions vllm_ascend/distributed/mooncake_connector.py
@@ -109,20 +109,27 @@ def __init__(self):
# intentionally delayed. Each entry is a tuple of (request_id,
# timestamp). If a request remains in this queue for too long, it will
# be force-freed.
self.record_finished_requests: set[str] = set()
self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()
self.reqs_to_process: set[str] = set()

def add_req_to_process(self, request_id: str):
self.reqs_to_process.add(request_id)

def add_not_transfer_request(self, request_id: str):
with self.done_task_lock:
self.finished_requests.add(request_id)
self.reqs_to_process.discard(request_id)

def update_done_task_count(self, request_id: str):
with self.done_task_lock:
self.finished_requests.add(request_id)
Collaborator:
Will placing this line of code back in its original position cause any issues?

Collaborator Author:
For a P node, if the request has been force-freed, it will no longer be in delayed_free_requests. This indicates it was already marked as finished and does not need to be marked again.
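
A minimal, runnable sketch of that lifecycle, assuming a simplified stand-in for the tracker (TrackerSketch and ABORT_TIMEOUT are illustrative names, not the real class; the actual code uses envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT and runs inside the send/recv threads):

```python
# Illustrative sketch only: shows why a force-freed request must not be
# marked finished a second time when its completion arrives late.
import threading
import time
from collections import OrderedDict

ABORT_TIMEOUT = 120.0  # stand-in for envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT


class TrackerSketch:

    def __init__(self):
        self.done_task_lock = threading.Lock()
        self.finished_requests: set[str] = set()
        self.reqs_to_process: set[str] = set()
        self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()

    def force_free_expired(self) -> set[str]:
        """Pop requests whose delayed-free timer has expired (oldest first)."""
        expired: set[str] = set()
        with self.done_task_lock:
            now = time.monotonic()
            while self.delayed_free_requests:
                req_id, start = next(iter(self.delayed_free_requests.items()))
                if now - start <= ABORT_TIMEOUT:
                    break
                # Force-free: the request leaves both structures, so it is
                # already accounted for as finished from this point on.
                self.delayed_free_requests.popitem(last=False)
                self.reqs_to_process.discard(req_id)
                expired.add(req_id)
        return expired

    def update_done_task_count(self, req_id: str) -> None:
        with self.done_task_lock:
            if req_id in self.reqs_to_process:
                self.finished_requests.add(req_id)
                self.reqs_to_process.discard(req_id)
                self.delayed_free_requests.pop(req_id, None)
            # else: the request was force-freed earlier and must not be
            # marked finished again (this is where the diff logs an error).


tracker = TrackerSketch()
tracker.reqs_to_process.add("req-1")
tracker.delayed_free_requests["req-1"] = time.monotonic() - ABORT_TIMEOUT - 1
assert tracker.force_free_expired() == {"req-1"}   # force-freed on timeout
tracker.update_done_task_count("req-1")            # late completion: no-op
assert "req-1" not in tracker.finished_requests
```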

if request_id in self.delayed_free_requests:
self._remove_delayed_requests(request_id)
if request_id in self.reqs_to_process:
self.finished_requests.add(request_id)
self.reqs_to_process.discard(request_id)
self.delayed_free_requests.pop(request_id, None)
Contributor:
You can add an else branch to log an error. If the req_id received by update_done_task_count is not in reqs_to_process, something abnormal has happened, which may indicate an accuracy issue.

Collaborator Author:
Yes, we need to remind users about this here.

else:
self.record_finished_requests.add(request_id)
logger.error(
"MooncakeConnector finish req not in reqs to process.If it is a P node, this request may have been force freed."
)

def get_and_clear_finished_requests(self) -> set[str]:
"""
@@ -140,10 +147,7 @@ def get_and_clear_finished_requests(self) -> set[str]:
def add_delayed_request(self, request_id: str, delay_start_time: float):
"""Add a delayed free request."""
with self.done_task_lock:
if request_id not in self.record_finished_requests:
self.delayed_free_requests[request_id] = delay_start_time
else:
self.record_finished_requests.discard(request_id)
self.delayed_free_requests[request_id] = delay_start_time

def _retrieve_expired_requests(self):
"""Retrieve all expired delayed requests."""
@@ -156,16 +160,13 @@ def _retrieve_expired_requests(self):
if (current_time - delay_start_time
> envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT):
self.delayed_free_requests.popitem(last=False)
self.reqs_to_process.discard(request_id)
expired_requests.add(request_id)
logger.info("Force freed request: %s", request_id)
else:
break
return expired_requests

def _remove_delayed_requests(self, request_id: str):
"""Remove all delayed free requests matching the given request_id."""
self.delayed_free_requests.pop(request_id)


class KVCacheSendingThread(threading.Thread):

@@ -769,6 +770,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
def __init__(self):
self.requests: dict[str, ReqMeta] = {}
self.requests_to_send: dict[str, float] = {}
self.reqs_in_batch: set[str] = set()

def add_new_req(
self,
@@ -932,6 +934,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[str, tuple[Request, list[int], int]] = {}
self._reqs_need_send: dict[str, float] = {}
self._reqs_in_batch: set[str] = set()

# master-slave meta information for cross-nodes
self.multi_nodes_meta_mapping: dict[str, dict[str, Any]] = {}
@@ -980,6 +983,9 @@ def update_state_after_alloc(self, request: "Request",
"num_external_tokens=%s, kv_transfer_params=%s",
num_external_tokens, params)

if params is not None and (params.get("do_remote_prefill", False)
or params.get("do_remote_decode", False)):
self._reqs_in_batch.add(request.request_id)
if params is not None and params.get("do_remote_prefill"):
if params.get("remote_block_ids"):
if all(p in params for p in ("remote_engine_id", "remote_host",
@@ -1022,6 +1028,8 @@ def build_connector_meta(
self._reqs_need_recv.clear()
meta.requests_to_send = self._reqs_need_send
self._reqs_need_send = {}
meta.reqs_in_batch = self._reqs_in_batch
self._reqs_in_batch = set()

return meta

@@ -1601,6 +1609,12 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
all_task_done=(i == self.tp_num_need_pulls *
self._prefill_pp_size - 1))

for req_id in metadata.reqs_in_batch:
if self.kv_send_thread is not None:
self.kv_send_thread.task_tracker.add_req_to_process(req_id)
if self.kv_recv_thread is not None:
self.kv_recv_thread.task_tracker.add_req_to_process(req_id)

if self.kv_send_thread is not None and self.pcp_size * self.dcp_size == 1:
for req_id, delay_start_time in metadata.requests_to_send.items():
if self.tp_rank in self._prefill_get_remote_rank(req_id):