Skip to content
Merged
21 changes: 21 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
KV cache helper for store.
"""

from collections.abc import Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal

Expand Down Expand Up @@ -203,6 +204,26 @@ def copy_kv_blocks(
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)


def yield_req_data(
    scheduler_output,
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
    """Iterate over every request scheduled in this step.

    Yields:
        (req_id, new_block_id_groups, preempted) — ``preempted`` is True only
        for cached requests that are resuming after preemption.
    """
    # Newly scheduled requests can never be resuming from preemption.
    for new_req in scheduler_output.scheduled_new_reqs:
        yield new_req.req_id, new_req.block_ids, False

    # Requests that were already in the running batch.
    cached = scheduler_output.scheduled_cached_reqs
    resumed_ids = cached.resumed_req_ids
    for rid, block_groups in zip(cached.req_ids, cached.new_block_ids):
        yield rid, block_groups, rid in resumed_ids


@dataclass
class TpKVTopology:
"""
Expand Down
45 changes: 30 additions & 15 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import EngineId, TpKVTopology
from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId,
TpKVTopology,
yield_req_data,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp,
KVConnectorBase_V1,
Expand Down Expand Up @@ -481,7 +485,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_need_save: dict[ReqId, Request] = {}
# Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {}
self._reqs_in_batch: set[ReqId] = set()
Expand Down Expand Up @@ -627,16 +631,7 @@ def update_state_after_alloc(
if self.use_host_buffer and params.get("do_remote_decode"):
# NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer.

# save all blocks
block_ids = blocks.get_block_ids()[0]
# TODO: skip the blocks that are already in the host xfer buffer.
# Currently, the host xfer buffer block is 1-to-1 mapped to device
# kv blocks, so host blocks won't be flushed as long as its device
# block is not overwritten; and it will be safe to skip saving them
# to host xfer buffer.
if block_ids:
self._reqs_need_save[request.request_id] = (request, block_ids)
self._reqs_need_save[request.request_id] = request
elif params.get("do_remote_prefill"):
if params.get("remote_block_ids"):
if all(
Expand Down Expand Up @@ -688,21 +683,39 @@ def build_connector_meta(
kv_transfer_params=req.kv_transfer_params,
)

for req_id, (req, block_ids) in self._reqs_need_save.items():
# NOTE: For the prefill side, there might be a chance that an early added
# request is a chunked prefill, so we need to check if new blocks are added
for req_id, new_block_id_groups, _ in yield_req_data(scheduler_output):
req_to_save = self._reqs_need_save.get(req_id)
if req_to_save is None or new_block_id_groups is None:
continue
req = req_to_save

assert req.kv_transfer_params is not None
meta.add_new_req_to_save(
request_id=req_id,
local_block_ids=block_ids,
local_block_ids=new_block_id_groups[0],
kv_transfer_params=req.kv_transfer_params,
)
assert scheduler_output.num_scheduled_tokens is not None
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
is_partial = (
req.num_computed_tokens + num_scheduled_tokens
) < req.num_prompt_tokens
Comment on lines +702 to +704
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might work...
But I think it's safer to just move self._reqs_need_save.pop(req_id) to request_finished.
@NickLucche WDYT?

Copy link
Copy Markdown
Contributor Author

@xuechendi xuechendi Dec 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._reqs_need_save.clear() was originally in this function, so I prefer to keep self._reqs_need_save.pop(req_id) here instead of changing the logic more?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep let's keep this here this is a tmp buffer, its function terminates here once metadata are built

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NickLucche @xuechendi I think you are missing an edge case:

def finish_requests(
self,
request_ids: str | Iterable[str],
finished_status: RequestStatus,
) -> None:
"""Handles the finish signal from outside the scheduler.
For example, the API server can abort a request when the client
disconnects.
"""

When the scheduler finish requests before they finish processing, you will not clear self._reqs_need_save.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which is every time a request is aborted too, nice catch @orozery !

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @orozery ! @NickLucche , I've updated the PR with clean up in request_finished, validated locally

if not is_partial:
# For non-partial prefills, once new req_meta is scheduled, it
# can be removed from _reqs_need_save.
# For partial prefill case, we will retain the request in
# _reqs_need_save until all blocks are scheduled with req_meta.
# Therefore, only pop if `not is_partial`.
self._reqs_need_save.pop(req_id)

meta.reqs_to_send = self._reqs_need_send
meta.reqs_in_batch = self._reqs_in_batch
meta.reqs_not_processed = self._reqs_not_processed

# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
self._reqs_need_save.clear()
self._reqs_in_batch = set()
self._reqs_not_processed = set()
self._reqs_need_send = {}
Expand Down Expand Up @@ -748,6 +761,8 @@ def request_finished(
# Also include the case of a P/D Prefill request with immediate
# block free (eg abort). Stop tracking this request.
self._reqs_not_processed.add(request.request_id)
# Clear _reqs_need_save if a request is aborted as partial prefill.
self._reqs_need_save.pop(request.request_id, None)
Comment on lines +764 to +765
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should work, but it seems more fragile to me.
I would go further and put this pop right at the beginning.
Then, you could also remove the entire is_partial check from build_connector_meta.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since _reqs_need_save is originally used to buffer which requests should create req_meta for saving to host, its life cycle runs from "scheduled" to "all request metadata are created".

If we go with your proposal, the life cycle instead runs from "scheduled" to "request ends".
This changes the original design.

@NickLucche, do you think we should do that?

I assume the fix here is to just handle a corner case when request was aborted ?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think @orozery is proposing to only clear the id on request finished, so either terminal block was processed or abort/error.
I think it makes sense but I don't necessarily think the current implementation is fragile.

Hence I don't have a strong opinion here, this could also be done in a separate PR, as long as we maximize clarity for these cases.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NickLucche @orozery , let's do that in a separate PR. Since the other queues (_reqs_to_recv, _reqs_need_send, etc.) are cleared in build_connector_meta, I would prefer not adding refactoring scope to this PR (which is originally an accuracy fix). Once this one is merged, I can open a new PR and we can work out a better design there.

return False, None

# TODO: check whether block_ids actually ever be 0. If not we could
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterable, Iterator
from collections.abc import Iterable
from dataclasses import dataclass
from itertools import islice
from typing import Any, ClassVar
Expand All @@ -12,6 +12,7 @@
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
Expand Down Expand Up @@ -516,23 +517,3 @@ def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
del self._store_jobs[req_id]

return finished_sending, finished_recving


def yield_req_data(
    scheduler_output,
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
    """Yield (req_id, new_block_id_groups, preempted) for all scheduled reqs.

    New requests are emitted first with ``preempted=False``; cached requests
    follow, flagged True when their id appears in ``resumed_req_ids``.
    """
    # New requests: never resumed from preemption by definition.
    yield from (
        (req_data.req_id, req_data.block_ids, False)
        for req_data in scheduler_output.scheduled_new_reqs
    )

    # Cached requests: pair ids with their new blocks and a resumed flag.
    cached_reqs = scheduler_output.scheduled_cached_reqs
    preempted_flags = [
        req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids
    ]
    yield from zip(cached_reqs.req_ids, cached_reqs.new_block_ids, preempted_flags)