Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions tests/v1/kv_connector/unit/test_multi_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,11 @@ def test_multi_example_connector_consistency():
]
# First three events are from initialization (register_kv_caches,
# set_host_xfer_buffer_ops, get_handshake_metadata), then generate() events.
assert events["storage1-WORKER"][:7] == [
assert events["storage1-WORKER"][:8] == [
"register_kv_caches",
"set_host_xfer_buffer_ops",
"get_handshake_metadata",
"handle_preemptions",
"bind_connector_metadata",
"start_load_kv",
"wait_for_layer_load",
Expand All @@ -246,10 +247,11 @@ def test_multi_example_connector_consistency():
"update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta",
]
assert events["storage2-WORKER"][:7] == [
assert events["storage2-WORKER"][:8] == [
"register_kv_caches",
"set_host_xfer_buffer_ops",
"get_handshake_metadata",
"handle_preemptions",
"bind_connector_metadata",
"start_load_kv",
"wait_for_layer_load",
Expand Down Expand Up @@ -399,8 +401,8 @@ def test_multi_connector_handle_preemptions_integration():
# testing the delegation behavior of MultiConnector here.
# The connector attribute contains the KV connector.
assert scheduler.connector is not None, "Scheduler should have a connector"
preempted_req_ids = {"req-1", "req-2", "req-3"}
scheduler.connector.handle_preemptions(preempted_req_ids)
connector_md = scheduler.connector.build_connector_meta(scheduler.schedule())
scheduler.connector.handle_preemptions(connector_md)

# Verify both connectors received the handle_preemptions call
events = get_connector_events()
Expand Down
5 changes: 1 addition & 4 deletions tests/v1/kv_connector/unit/test_offloading_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,10 +363,7 @@ def _run(self, decoded_tokens: list[int], complete_transfers: bool):
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)

if scheduler_output.preempted_req_ids:
self.worker_connector.handle_preemptions(
scheduler_output.preempted_req_ids
)
self.worker_connector.handle_preemptions(kv_connector_metadata)

self.worker_connector.bind_connector_metadata(kv_connector_metadata)
self.worker_connector.start_load_kv(self._dummy_ctx)
Expand Down
8 changes: 4 additions & 4 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@

Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata.
handle_preemptions() - called if there are preempted requests,
before their blocks are overwritten
handle_preemptions() - called for handling preempted requests
or request evicted blocks before they are overwritten

start_load_kv() - starts loading all KVs (maybe async)
wait_for_layer_load() - blocks until layer i load is done
Expand Down Expand Up @@ -288,9 +288,9 @@ def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
"""
return

def handle_preemptions(self, preempted_req_ids: set[str]):
def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata):
"""
Handle preempted requests BEFORE their blocks are overwritten.
Handle preempted requests or evicted blocks BEFORE they are overwritten.
Needed for connectors which use async saves (e.g., OffloadingConnector)
"""
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,11 @@ def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
for c in self._connectors:
c.set_host_xfer_buffer_ops(copy_operation)

def handle_preemptions(self, preempted_req_ids: set[str]):
def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata):
"""Handle preempted requests for all sub-connectors."""
for c in self._connectors:
c.handle_preemptions(preempted_req_ids)
assert isinstance(kv_connector_metadata, MultiKVConnectorMetadata)
for c, cm in zip(self._connectors, kv_connector_metadata.metadata):
c.handle_preemptions(cm)

def get_finished_count(self) -> int | None:
# TODO(https://github.com/vllm-project/vllm/issues/33400)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def record_transfer(self, num_bytes: int, time: float, transfer_type: TransferTy
class OffloadingConnectorMetadata(KVConnectorMetadata):
reqs_to_load: dict[ReqId, TransferSpec]
reqs_to_store: dict[ReqId, TransferSpec]
reqs_to_flush: set[str] | None = None


class OffloadingConnector(KVConnectorBase_V1):
Expand Down Expand Up @@ -146,9 +147,10 @@ def register_cross_layers_kv_cache(
assert self.connector_worker is not None
self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)

def handle_preemptions(self, preempted_req_ids: set[str]):
def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata):
assert self.connector_worker is not None
self.connector_worker.handle_preemptions(preempted_req_ids)
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
self.connector_worker.handle_preemptions(kv_connector_metadata)

def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
Expand Down Expand Up @@ -482,6 +484,7 @@ def build_connector_meta(
meta = OffloadingConnectorMetadata(
reqs_to_load=self._reqs_to_load,
reqs_to_store=self._get_reqs_to_store(scheduler_output),
reqs_to_flush=scheduler_output.preempted_req_ids,
)
self._reqs_to_load = {}

Expand Down Expand Up @@ -619,13 +622,13 @@ def register_cross_layers_kv_cache(
attn_backends = {cross_layer_name: attn_backend}
self._register_handlers(kv_caches, attn_backends)

def handle_preemptions(self, preempted_req_ids: set[str]):
def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it be documented that this is a breaking API change?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a fresh API that I recently introduced. It's not used by any in-tree connector, and most likely not used at all.
I don't see any benefit from documenting it.

for job_id, transfer_spec in self._unsubmitted_store_jobs:
success = self.worker.transfer_async(job_id, transfer_spec)
assert success
self._unsubmitted_store_jobs.clear()

for req_id in preempted_req_ids:
for req_id in kv_connector_metadata.reqs_to_flush or ():
job_ids = self._store_jobs.get(req_id)
if job_ids:
self.worker.wait(job_ids)
Expand Down
3 changes: 1 addition & 2 deletions vllm/v1/worker/gpu/kv_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,10 @@ def pre_forward(self, scheduler_output: "SchedulerOutput") -> None:
if self._disabled:
return

if scheduler_output.preempted_req_ids:
self.kv_connector.handle_preemptions(scheduler_output.preempted_req_ids)
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
self.kv_connector.bind_connector_metadata(kv_connector_metadata)
self.kv_connector.handle_preemptions(kv_connector_metadata)

# TODO: sort out KV Connectors' use of forward_context
if is_forward_context_available():
Expand Down
8 changes: 4 additions & 4 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3594,10 +3594,10 @@ def execute_model(
scheduled_spec_decode_tokens=spec_decode_tokens_copy,
)

if scheduler_output.preempted_req_ids and has_kv_transfer_group():
get_kv_transfer_group().handle_preemptions(
scheduler_output.preempted_req_ids
)
if has_kv_transfer_group():
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this assert here? Previously, it checked for ids were available.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, previously handle_preemptions only got a set of preempted request IDs, whereas now it gets KVConnectorMetadata.
has_kv_transfer_group() guarantees that kv_connector_metadata is not None (since build_connector_meta is called on each step), so that assert is fine.
BTW the same assert also exists in the forward pass (at KVConnectorModelRunnerMixin._get_kv_connector_output).

get_kv_transfer_group().handle_preemptions(kv_connector_metadata)
Comment on lines +3597 to +3600
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The handle_preemptions method is called here, but it's also called within ActiveKVConnector.pre_forward, which is invoked later in this execute_model function (via set_forward_context or kv_connector_no_forward). This results in handle_preemptions being called twice in each step.

While the current implementations appear to be idempotent, this redundancy can be confusing and might lead to bugs if a future connector's handle_preemptions is not idempotent. To centralize the logic, this call should be removed, relying on the one inside ActiveKVConnector.pre_forward.

Copy link
Copy Markdown
Collaborator Author

@orozery orozery Feb 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK pre_forward is only called in model runner v2.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@orozery Does the code make this obvious or enforced? Is it possible that it could be called twice?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All connector functions have 2 call locations, one for each model runner.
For a specific run, only one model runner will be used (either v1 or v2), so it's not possible functions will be called twice.


num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with (
Expand Down
Loading