@@ -43,6 +43,7 @@ def __init__(self, vllm_config: "VllmConfig", use_layerwise):
self._block_size *= self.dcp_size
# request_id -> full_token_ids
self._request_trackers: dict[str, RequestTracker] = {}
self._preempted_req_ids: set[str] = set()
# Whether to discard partial chunks
self._discard_partial_chunks = (
vllm_config.kv_transfer_config.get_from_extra_config(
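The newly added `_preempted_req_ids` set lets the connector tell a request that is being resumed after preemption (its tracker must be rebuilt from scratch) apart from an ordinary cached request (its tracker is advanced in place). A minimal lifecycle sketch of that idea; every name other than the set itself is illustrative:

```python
# Illustrative sketch only; the real connector stores RequestTracker
# objects and receives preemption events via SchedulerOutput.
preempted_req_ids: set[str] = set()
request_trackers: dict[str, object] = {}

def on_preempted(req_id: str) -> None:
    preempted_req_ids.add(req_id)        # remember the eviction
    request_trackers.pop(req_id, None)   # its tracker is now stale

def on_rescheduled(req_id: str) -> None:
    if req_id in preempted_req_ids:      # resumed request: start over
        preempted_req_ids.discard(req_id)
        request_trackers[req_id] = object()  # freshly rebuilt tracker
```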
@@ -161,6 +162,11 @@ def build_connector_meta(
self._request_trackers.pop(finished_req_id, None)
self._unfinished_requests.pop(finished_req_id, None)
self._unfinished_request_ids.discard(finished_req_id)

self._preempted_req_ids.update(scheduler_output.preempted_req_ids)
for req_id in scheduler_output.preempted_req_ids:
    self._request_trackers.pop(req_id, None)
    self._unfinished_requests.pop(req_id, None)

meta = AscendConnectorMetadata(self._unfinished_request_ids,
                               scheduler_output.preempted_req_ids)
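Both the unfinished and the preempted request ids travel to the workers inside the metadata object constructed above. Only the two constructor arguments and `add_request` are visible in this diff; a plausible, unverified sketch of the receiving container:

```python
from dataclasses import dataclass, field

# Sketch inferred from the call sites in this diff; anything beyond the
# two constructor arguments and add_request() is an assumption, not the
# real AscendConnectorMetadata definition.
@dataclass
class ConnectorMetadataSketch:
    unfinished_request_ids: set[str]
    preempted_req_ids: set[str]
    requests: list = field(default_factory=list)

    def add_request(self, req_meta) -> None:
        self.requests.append(req_meta)
```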

@@ -170,15 +176,24 @@ def build_connector_meta(
num_tokens_to_compute = (
    request.num_computed_tokens +
    scheduler_output.num_scheduled_tokens[request.req_id])
request_tuple = self._unfinished_requests.get(request.req_id)
request_real = request_tuple[0]  # type: ignore[index]
if not isinstance(request.block_ids[0], list):
    unfolded_block_ids = request.block_ids.copy()
else:
    unfolded_block_ids = request.block_ids[0].copy()
request_tracker = RequestTracker(
    req_id=request.req_id,
    token_len=num_tokens_to_compute,
    allocated_block_ids=unfolded_block_ids,
    num_saved_tokens=0,
)
self._request_trackers[request.req_id] = request_tracker
last_chunk_tokens_num = (
    (len(request.prompt_token_ids) // self._block_size *
     self._block_size) if self._discard_partial_chunks else
    len(request.prompt_token_ids))

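The `last_chunk_tokens_num` expression above is plain integer arithmetic: when partial chunks are discarded, the prompt length is rounded down to a multiple of the block size. A quick worked example:

```python
# 300-token prompt, block size 128: rounding down keeps 2 full blocks,
# i.e. 256 tokens; without discard_partial_chunks all 300 tokens count.
block_size, prompt_len = 128, 300
assert prompt_len // block_size * block_size == 256
```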
req_meta = ReqMeta.from_request_tracker(
    request_tracker,
    self._block_size,
@@ -195,38 +210,78 @@
cached_reqs = scheduler_output.scheduled_cached_reqs
if not force_skip_save:
    for i, req_id in enumerate(cached_reqs.req_ids):
        new_block_ids = cached_reqs.new_block_ids[i]
        # resumed (previously preempted) request
        if req_id in self._preempted_req_ids:
            if isinstance(new_block_ids, tuple):
                new_block_ids = new_block_ids[0].copy()
            else:
                new_block_ids = new_block_ids.copy()
            self._preempted_req_ids.discard(req_id)
            load_spec = self.load_specs.pop(req_id, None)
            request_tuple = self._unfinished_requests.get(req_id)
            request_real = request_tuple[0]  # type: ignore[index]
            num_tokens_to_compute = (
                request_real.num_computed_tokens +
                scheduler_output.num_scheduled_tokens[req_id])
            request_tracker = RequestTracker(
                req_id=req_id,
                token_len=num_tokens_to_compute,
                allocated_block_ids=new_block_ids,
                num_saved_tokens=0,
            )
            self._request_trackers[req_id] = request_tracker
            last_chunk_tokens_num = (
                (len(request_real.prompt_token_ids) // self._block_size *
                 self._block_size) if self._discard_partial_chunks else
                len(request_real.prompt_token_ids))
            req_meta = ReqMeta.from_request_tracker(
                request_tracker,
                self._block_size,
                load_spec=load_spec,
                skip_save=force_skip_save,
                block_hashes=request_real.block_hashes,
                is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
                discard_partial_chunks=self._discard_partial_chunks,
            )
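            # Resumed after preemption: the tracker is rebuilt with
            # num_saved_tokens=0, i.e. the KV prefix is re-saved from the
            # start (a reading of this diff, not a documented invariant).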

        # decode/chunked request
        else:
            request_tracker = self._request_trackers[req_id]
            num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
            req_tuple = self._unfinished_requests.get(req_id)
            if req_tuple:
                request = req_tuple[0]
                num_current_tokens = request_tracker.token_len
                new_token_ids = request.all_token_ids[
                    num_current_tokens:num_current_tokens + num_new_tokens]
                request_tracker.token_len += len(new_token_ids)
            else:
                raise ValueError(
                    f"Request {req_id} is not in _unfinished_requests, "
                    f"but it is scheduled to be cached")
            num_computed_token = cached_reqs.num_computed_tokens[i]
            if num_computed_token >= len(request.prompt_token_ids):
                continue
            request_tracker.update(new_block_ids)
            last_chunk_tokens_num = (
                (len(request.prompt_token_ids) // self._block_size *
                 self._block_size) if self._discard_partial_chunks else
                len(request.prompt_token_ids))
            req_meta = ReqMeta.from_request_tracker(
                request_tracker,
                self._block_size,
                load_spec=None,
                skip_save=force_skip_save,
                block_hashes=request.block_hashes,
                is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
                discard_partial_chunks=self._discard_partial_chunks,
            )
        if req_meta is not None:
            meta.add_request(req_meta)
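The `isinstance(new_block_ids, tuple)` normalization in the resumed branch mirrors the `request.block_ids[0]` unfolding done for new requests: block ids may arrive either as a flat list or grouped per KV cache group, and only the first group is used. A self-contained sketch of that normalization; the helper name is invented for illustration:

```python
# Hypothetical helper illustrating the two shapes handled above: a flat
# list of block ids, or a tuple/list of per-KV-cache-group id lists.
def unfold_block_ids(block_ids) -> list[int]:
    if isinstance(block_ids, tuple) or (block_ids
                                        and isinstance(block_ids[0], list)):
        return list(block_ids[0])  # first KV cache group only
    return list(block_ids)

assert unfold_block_ids([1, 2, 3]) == [1, 2, 3]
assert unfold_block_ids(([4, 5],)) == [4, 5]
assert unfold_block_ids([[6, 7]]) == [6, 7]
```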
