diff --git a/vllm_ascend/distributed/kvpool/pool_scheduler.py b/vllm_ascend/distributed/kvpool/pool_scheduler.py index 67a3b68e1b3..753a304226c 100644 --- a/vllm_ascend/distributed/kvpool/pool_scheduler.py +++ b/vllm_ascend/distributed/kvpool/pool_scheduler.py @@ -162,13 +162,14 @@ def build_connector_meta( self._request_trackers.pop(finished_req_id, None) self._unfinished_requests.pop(finished_req_id, None) self._unfinished_request_ids.discard(finished_req_id) - + for req_id in scheduler_output.preempted_req_ids: self._preempted_req_ids.update(scheduler_output.preempted_req_ids) self._request_trackers.pop(req_id, None) self._unfinished_requests.pop(req_id, None) - meta = AscendConnectorMetadata(self._unfinished_request_ids, scheduler_output.preempted_req_ids) + meta = AscendConnectorMetadata(self._unfinished_request_ids, + scheduler_output.preempted_req_ids) for request in scheduler_output.scheduled_new_reqs: # Right now, we only load KV for new requests @@ -183,17 +184,17 @@ def build_connector_meta( else: unfolded_block_ids = request.block_ids[0].copy() request_tracker = RequestTracker( - req_id=request.req_id, - token_len=num_tokens_to_compute, - allocated_block_ids=unfolded_block_ids, - num_saved_tokens=0, - ) + req_id=request.req_id, + token_len=num_tokens_to_compute, + allocated_block_ids=unfolded_block_ids, + num_saved_tokens=0, + ) self._request_trackers[request.req_id] = request_tracker last_chunk_tokens_num = ((len(request.prompt_token_ids) // self._block_size * self._block_size) if self._discard_partial_chunks else len( request.prompt_token_ids)) - + req_meta = ReqMeta.from_request_tracker( request_tracker, self._block_size, @@ -233,10 +234,11 @@ def build_connector_meta( num_saved_tokens=0, ) self._request_trackers[req_id] = request_tracker - last_chunk_tokens_num = ((len(request_real.prompt_token_ids) // - self._block_size * self._block_size) - if self._discard_partial_chunks else len( - request_real.prompt_token_ids)) + last_chunk_tokens_num = ( + (len(request_real.prompt_token_ids) // + self._block_size * + self._block_size) if self._discard_partial_chunks else + len(request_real.prompt_token_ids)) req_meta = ReqMeta.from_request_tracker( request_tracker, self._block_size, @@ -247,17 +249,19 @@ def build_connector_meta( >= last_chunk_tokens_num, discard_partial_chunks=self._discard_partial_chunks, ) - + # decode/chunked request else: request_tracker = self._request_trackers[req_id] - num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_new_tokens = scheduler_output.num_scheduled_tokens[ + req_id] req_tuple = self._unfinished_requests.get(req_id) if req_tuple: request = req_tuple[0] num_current_tokens = request_tracker.token_len new_token_ids = request.all_token_ids[ - num_current_tokens:num_current_tokens + num_new_tokens] + num_current_tokens:num_current_tokens + + num_new_tokens] request_tracker.token_len += len(new_token_ids) else: raise ValueError( @@ -269,9 +273,10 @@ def build_connector_meta( request_tracker.update(new_block_ids) last_chunk_tokens_num = ((len(request.prompt_token_ids) // - self._block_size * self._block_size) - if self._discard_partial_chunks else - len(request.prompt_token_ids)) + self._block_size * + self._block_size) if + self._discard_partial_chunks else + len(request.prompt_token_ids)) req_meta = ReqMeta.from_request_tracker( request_tracker, self._block_size, diff --git a/vllm_ascend/distributed/kvpool/pool_worker.py b/vllm_ascend/distributed/kvpool/pool_worker.py index 563dd6174c8..cc352198d9f 100644 --- a/vllm_ascend/distributed/kvpool/pool_worker.py +++ b/vllm_ascend/distributed/kvpool/pool_worker.py @@ -461,11 +461,13 @@ def store_layer( for layer_id in range(self.num_layers): yield - def get_finished(self, - finished_req_ids: set[str], meta:AscendConnectorMetadata) -> tuple[set[str], set[str]]: + def get_finished( + self, finished_req_ids: set[str], + meta: AscendConnectorMetadata) -> tuple[set[str], set[str]]: done_sending = ( self.get_and_clear_finished_requests( - finished_req_ids, meta # type: ignore[union-attr] + finished_req_ids, + meta # type: ignore[union-attr] ) if self.kv_role in ['kv_producer', 'kv_both'] or self.consumer_is_to_put else set()) @@ -480,7 +482,8 @@ def get_finished(self, self.tp_rank) return done_sending, done_recving - def get_and_clear_finished_requests(self, finished_req_ids, meta:AscendConnectorMetadata) -> set[str]: + def get_and_clear_finished_requests( + self, finished_req_ids, meta: AscendConnectorMetadata) -> set[str]: finished_sending = set() for req_id in meta.preempted_req_ids: self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]