diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py
index a50f950c4cd..cce4c53d7d0 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py
@@ -43,6 +43,7 @@ def __init__(self, vllm_config: "VllmConfig", use_layerwise):
         self._block_size *= self.dcp_size
         # request_id -> full_token_ids
         self._request_trackers: dict[str, RequestTracker] = {}
+        self._preempted_req_ids: set[str] = set()
         # Whether to discard partial chunks
         self._discard_partial_chunks = (
             vllm_config.kv_transfer_config.get_from_extra_config(
@@ -161,6 +162,12 @@ def build_connector_meta(
             self._request_trackers.pop(finished_req_id, None)
             self._unfinished_requests.pop(finished_req_id, None)
             self._unfinished_request_ids.discard(finished_req_id)
+
+        # Preempted requests lose their blocks; their trackers are rebuilt on resume.
+        self._preempted_req_ids.update(scheduler_output.preempted_req_ids)
+        for req_id in scheduler_output.preempted_req_ids:
+            self._request_trackers.pop(req_id, None)
+            self._unfinished_requests.pop(req_id, None)
 
         meta = AscendConnectorMetadata(self._unfinished_request_ids,
                                        scheduler_output.preempted_req_ids)
@@ -170,15 +177,25 @@ def build_connector_meta(
             num_tokens_to_compute = (
                 request.num_computed_tokens +
                 scheduler_output.num_scheduled_tokens[request.req_id])
-            request_tracker = RequestTracker.from_new_request(
-                request, num_tokens_to_compute)
+            request_tuple = self._unfinished_requests.get(request.req_id)
+            request_real = request_tuple[0]  # type: ignore[index]
+            # block_ids may be nested as one list per KV cache group.
+            if not isinstance(request.block_ids[0], list):
+                unfolded_block_ids = request.block_ids.copy()
+            else:
+                unfolded_block_ids = request.block_ids[0].copy()
+            request_tracker = RequestTracker(
+                req_id=request.req_id,
+                token_len=num_tokens_to_compute,
+                allocated_block_ids=unfolded_block_ids,
+                num_saved_tokens=0,
+            )
             self._request_trackers[request.req_id] = request_tracker
             last_chunk_tokens_num = ((len(request.prompt_token_ids) //
                                       self._block_size * self._block_size)
                                      if self._discard_partial_chunks else len(
                                          request.prompt_token_ids))
-            request_tuple = self._unfinished_requests.get(request.req_id)
-            request_real = request_tuple[0]  # type: ignore[index]
+
             req_meta = ReqMeta.from_request_tracker(
                 request_tracker,
                 self._block_size,
@@ -195,38 +212,78 @@ def build_connector_meta(
         cached_reqs = scheduler_output.scheduled_cached_reqs
         if not force_skip_save:
             for i, req_id in enumerate(cached_reqs.req_ids):
-                request_tracker = self._request_trackers[req_id]
-                num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
-                req_tuple = self._unfinished_requests.get(req_id)
-                if req_tuple:
-                    request = req_tuple[0]
-                    num_current_tokens = request_tracker.token_len
-                    new_token_ids = request.all_token_ids[
-                        num_current_tokens:num_current_tokens + num_new_tokens]
-                    request_tracker.token_len += len(new_token_ids)
-                else:
-                    raise ValueError(
-                        f"Request {req_id} is not in _unfinished_requests, "
-                        f"but it is scheduled to be cached")
                 new_block_ids = cached_reqs.new_block_ids[i]
                 if not new_block_ids:
                     continue
-                request_tracker.update(new_block_ids)
-
-                last_chunk_tokens_num = ((len(request.prompt_token_ids) //
-                                          self._block_size * self._block_size)
-                                         if self._discard_partial_chunks else
-                                         len(request.prompt_token_ids))
-                req_meta = ReqMeta.from_request_tracker(
-                    request_tracker,
-                    self._block_size,
-                    load_spec=None,
-                    skip_save=force_skip_save,
-                    block_hashes=request.block_hashes,
-                    is_last_chunk=request_tracker.token_len
-                    >= last_chunk_tokens_num,
-                    discard_partial_chunks=self._discard_partial_chunks,
-                )
+                # resumed request: rebuild the tracker from the fresh allocation
+                if req_id in self._preempted_req_ids:
+                    if isinstance(new_block_ids, tuple):
+                        new_block_ids = new_block_ids[0].copy()
+                    else:
+                        new_block_ids = new_block_ids.copy()
+                    self._preempted_req_ids.discard(req_id)
+                    load_spec = self.load_specs.pop(req_id, None)
+                    request_tuple = self._unfinished_requests.get(req_id)
+                    request_real = request_tuple[0]  # type: ignore[index]
+                    num_tokens_to_compute = (
+                        request_real.num_computed_tokens +
+                        scheduler_output.num_scheduled_tokens[req_id])
+                    request_tracker = RequestTracker(
+                        req_id=req_id,
+                        token_len=num_tokens_to_compute,
+                        allocated_block_ids=new_block_ids,
+                        num_saved_tokens=0,
+                    )
+                    self._request_trackers[req_id] = request_tracker
+                    last_chunk_tokens_num = ((len(request_real.prompt_token_ids) //
+                                              self._block_size * self._block_size)
+                                             if self._discard_partial_chunks else len(
+                                                 request_real.prompt_token_ids))
+                    req_meta = ReqMeta.from_request_tracker(
+                        request_tracker,
+                        self._block_size,
+                        load_spec=load_spec,
+                        skip_save=force_skip_save,
+                        block_hashes=request_real.block_hashes,
+                        is_last_chunk=request_tracker.token_len
+                        >= last_chunk_tokens_num,
+                        discard_partial_chunks=self._discard_partial_chunks,
+                    )
+
+                # decode/chunked request: extend the existing tracker
+                else:
+                    request_tracker = self._request_trackers[req_id]
+                    num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
+                    req_tuple = self._unfinished_requests.get(req_id)
+                    if req_tuple:
+                        request = req_tuple[0]
+                        num_current_tokens = request_tracker.token_len
+                        new_token_ids = request.all_token_ids[
+                            num_current_tokens:num_current_tokens + num_new_tokens]
+                        request_tracker.token_len += len(new_token_ids)
+                    else:
+                        raise ValueError(
+                            f"Request {req_id} is not in _unfinished_requests, "
+                            f"but it is scheduled to be cached")
+                    num_computed_token = cached_reqs.num_computed_tokens[i]
+                    if num_computed_token >= len(request.prompt_token_ids):
+                        continue
+                    request_tracker.update(new_block_ids)
+
+                    last_chunk_tokens_num = ((len(request.prompt_token_ids) //
+                                              self._block_size * self._block_size)
+                                             if self._discard_partial_chunks else
+                                             len(request.prompt_token_ids))
+                    req_meta = ReqMeta.from_request_tracker(
+                        request_tracker,
+                        self._block_size,
+                        load_spec=None,
+                        skip_save=force_skip_save,
+                        block_hashes=request.block_hashes,
+                        is_last_chunk=request_tracker.token_len
+                        >= last_chunk_tokens_num,
+                        discard_partial_chunks=self._discard_partial_chunks,
+                    )
 
                 if req_meta is not None:
                     meta.add_request(req_meta)
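
The pattern the patch implements, reduced to a minimal standalone sketch. `Tracker`, `StepOutput`, and `PoolTracker` below are simplified stand-ins for `RequestTracker`, `SchedulerOutput`, and the scheduler-side connector, not the real vLLM classes: preemption frees a request's blocks, so its tracker is dropped and the id remembered; when the request is scheduled again, the tracker is rebuilt from the fresh block allocation with `num_saved_tokens=0`, so the KV pool saves the request's cache from the beginning again.

```python
from dataclasses import dataclass, field


@dataclass
class Tracker:
    """Simplified stand-in for RequestTracker."""
    req_id: str
    token_len: int
    block_ids: list[int]
    num_saved_tokens: int = 0


@dataclass
class StepOutput:
    """Simplified stand-in for SchedulerOutput."""
    preempted_req_ids: set[str] = field(default_factory=set)
    # req_id -> (tokens to compute, fresh block allocation)
    resumed: dict[str, tuple[int, list[int]]] = field(default_factory=dict)


class PoolTracker:
    """Simplified stand-in for the scheduler-side connector."""

    def __init__(self) -> None:
        self.trackers: dict[str, Tracker] = {}
        self.preempted: set[str] = set()

    def on_step(self, out: StepOutput) -> None:
        # Preemption frees the request's blocks, so the old tracker is
        # stale: drop it and remember the request id.
        self.preempted.update(out.preempted_req_ids)
        for req_id in out.preempted_req_ids:
            self.trackers.pop(req_id, None)

        # On resumption, rebuild the tracker from the fresh allocation
        # instead of appending to stale state; num_saved_tokens restarts
        # at 0 so the KV pool re-saves from the first block.
        for req_id, (token_len, block_ids) in out.resumed.items():
            if req_id in self.preempted:
                self.preempted.discard(req_id)
                self.trackers[req_id] = Tracker(req_id, token_len,
                                                list(block_ids))


pool = PoolTracker()
pool.trackers["r1"] = Tracker("r1", token_len=128, block_ids=[0, 1])
pool.on_step(StepOutput(preempted_req_ids={"r1"}))       # r1 dropped
pool.on_step(StepOutput(resumed={"r1": (128, [7, 8])}))  # r1 rebuilt
assert pool.trackers["r1"].block_ids == [7, 8]
assert pool.trackers["r1"].num_saved_tokens == 0
```

Keeping the ids in a dedicated set (rather than consulting `scheduler_output.preempted_req_ids` alone) matters because the resumption can arrive many scheduler steps after the preemption, by which point the preempting step's output is gone.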