Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm_ascend/distributed/kvpool/ascend_store_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def get_finished(self,
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
done_sending, done_recving = self.connector_worker.get_finished(
finished_req_ids)
finished_req_ids, self._get_connector_metadata())
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The `self._get_connector_metadata()` method is inherited from `KVConnectorBase_V1`. The `get_finished` method in `KVPoolWorker` now explicitly expects an `AscendConnectorMetadata` object, which includes `preempted_req_ids`. If the inherited `_get_connector_metadata()` returns a generic `KVConnectorMetadata` (which it likely does, as `AscendConnectorMetadata` is a specific implementation), accessing `meta.preempted_req_ids` in `KVPoolWorker` will raise an `AttributeError`. To ensure type safety and correct functionality, `AscendStoreConnector` should override `_get_connector_metadata` to explicitly return an `AscendConnectorMetadata` instance, ensuring it contains the necessary `preempted_req_ids` from the scheduler.

return done_sending, done_recving


Expand Down
3 changes: 2 additions & 1 deletion vllm_ascend/distributed/kvpool/config_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,10 @@ def from_request_tracker(

class AscendConnectorMetadata(KVConnectorMetadata):

def __init__(self, unfinished_request_ids):
def __init__(self, unfinished_request_ids, preempted_req_ids):
self.requests = []
self.unfinished_request_ids = unfinished_request_ids
self.preempted_req_ids = preempted_req_ids

def add_request(self, req_meta: ReqMeta) -> None:
"""Add a request to the metadata.
Expand Down
18 changes: 12 additions & 6 deletions vllm_ascend/distributed/kvpool/kv_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ def add_stored_request(self, req_id: str):
with self.done_task_lock:
self.stored_requests[req_id] += 1

def dec_stored_request(self, req_id: str):
    """Decrement the pending-store counter kept for ``req_id``.

    The update happens while holding ``done_task_lock``. A request id that
    is not currently a key of ``stored_requests`` is ignored, so this call
    never creates a new entry for it.
    """
    with self.done_task_lock:
        # Guard first: untracked ids must stay untracked.
        if req_id not in self.stored_requests:
            return
        self.stored_requests[req_id] -= 1

def delete_finished_stored_request(self, req_id: str):
with self.done_task_lock:
if req_id in self.stored_requests:
Expand All @@ -129,6 +134,10 @@ def _handle_request(self, req_meta: ReqMeta):
starts = []
ends = []
keys = []
if req_id not in self.stored_requests:
self.request_queue.task_done()
return
Comment on lines +137 to +139
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The check `if req_id not in self.stored_requests:` might not behave as expected. `self.stored_requests` is a `defaultdict(int)`. If a request's count has been decremented to 0 by `dec_stored_request` but the key has not been explicitly removed by `delete_finished_stored_request`, the `req_id` will still be present in `self.stored_requests` (with a value of 0). In such a case, this condition would be `False`, and the method would proceed to process a request that has no pending blocks to store, potentially leading to incorrect behavior or resource issues. Consider checking `if self.stored_requests.get(req_id, 0) <= 0:` instead, to correctly identify requests that are logically finished from the sending perspective.

Suggested change
if req_id not in self.stored_requests:
self.request_queue.task_done()
return
if self.stored_requests.get(req_id, 0) <= 0:
self.request_queue.task_done()
return


for start, end, key in self.token_database.process_tokens(
token_len, req_meta.block_hashes):
starts.append(start)
Expand All @@ -141,15 +150,13 @@ def _handle_request(self, req_meta: ReqMeta):
keys = keys[self.tp_rank % self.put_step::self.put_step]

if not keys:
with self.done_task_lock:
self.stored_requests[req_id] -= 1
self.dec_stored_request(req_id)
return

skip_block_num = self.lookup(keys)

if skip_block_num == len(keys):
with self.done_task_lock:
self.stored_requests[req_id] -= 1
self.dec_stored_request(req_id)
return

starts = starts[skip_block_num:]
Expand Down Expand Up @@ -188,8 +195,7 @@ def _handle_request(self, req_meta: ReqMeta):
current_event.synchronize()
self.m_store.put(keys, addrs, sizes)

with self.done_task_lock:
self.stored_requests[req_id] -= 1
self.dec_stored_request(req_id)
self.request_queue.task_done()


Expand Down
2 changes: 1 addition & 1 deletion vllm_ascend/distributed/kvpool/pool_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def build_connector_meta(
self._unfinished_requests.pop(finished_req_id, None)
self._unfinished_request_ids.discard(finished_req_id)

meta = AscendConnectorMetadata(self._unfinished_request_ids)
meta = AscendConnectorMetadata(self._unfinished_request_ids, scheduler_output.preempted_req_ids)

for request in scheduler_output.scheduled_new_reqs:
# Right now, we only load KV for new requests
Expand Down
9 changes: 6 additions & 3 deletions vllm_ascend/distributed/kvpool/pool_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,10 +462,10 @@ def store_layer(
yield

def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
finished_req_ids: set[str], meta:AscendConnectorMetadata) -> tuple[set[str], set[str]]:
done_sending = (
self.get_and_clear_finished_requests(
finished_req_ids # type: ignore[union-attr]
finished_req_ids, meta # type: ignore[union-attr]
) if self.kv_role in ['kv_producer', 'kv_both']
or self.consumer_is_to_put else set())

Expand All @@ -480,8 +480,11 @@ def get_finished(self,
self.tp_rank)
return done_sending, done_recving

def get_and_clear_finished_requests(self, finished_req_ids) -> set[str]:
def get_and_clear_finished_requests(self, finished_req_ids, meta:AscendConnectorMetadata) -> set[str]:
finished_sending = set()
for req_id in meta.preempted_req_ids:
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
req_id)
for req_id in self.kv_send_thread.stored_requests.copy( # type: ignore[union-attr]
):
if self.kv_send_thread.stored_requests[ # type: ignore[union-attr]
Expand Down
Loading