diff --git a/tests/v1/kv_connector/unit/offloading_connector/test_scheduler.py b/tests/v1/kv_connector/unit/offloading_connector/test_scheduler.py index 8d2c45f7bd20..3ea935b090df 100644 --- a/tests/v1/kv_connector/unit/offloading_connector/test_scheduler.py +++ b/tests/v1/kv_connector/unit/offloading_connector/test_scheduler.py @@ -232,6 +232,9 @@ def test_request_preemption(request_runner, async_scheduling: bool): expected_stored_gpu_block_indexes=(9, 10, 11), ) + # All stores completed before request_finished -> fence index empty. + assert runner.connector_scheduler._block_id_to_pending_jobs == {} + @pytest.mark.parametrize("async_scheduling", [True, False]) def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling: bool): @@ -292,6 +295,9 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling: # second request will use the GPU prefix cache assert transfer_jobs == list(runner.offloading_spec.handler.transfer_specs) + # Fence index drained: stores completed before request_finished ran. + assert runner.connector_scheduler._block_id_to_pending_jobs == {} + @pytest.mark.parametrize("async_scheduling", [True, False]) def test_abort_loading_requests(request_runner, async_scheduling: bool): @@ -534,3 +540,131 @@ def test_do_remote_decode_stores_all_blocks(request_runner, async_scheduling: bo decoded_tokens=[EOS_TOKEN_ID], expected_stored_gpu_block_indexes=(0, 1, 2, 3, 4, 5), ) + # All stores completed before request_finished -> fence index empty. + assert runner.connector_scheduler._block_id_to_pending_jobs == {} + + +# --------------------------------------------------------------------------- +# Tests for the per-job-store-completion design and fence invariants. +# --------------------------------------------------------------------------- + + +def test_loads_do_not_populate_fence_index(request_runner): + """Loads don't populate _block_id_to_pending_jobs (protected by + delay_free_blocks while in flight).""" + runner = request_runner( + offloaded_block_size=12, + gpu_block_size=4, + num_gpu_blocks=100, + async_scheduling=False, + ) + runner.new_request(token_ids=[0] * 12) + runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1 + runner.run(decoded_tokens=[], complete_transfers=False) + assert runner.connector_scheduler._block_id_to_pending_jobs == {} + + +def test_fence_at_update_state_after_alloc(request_runner): + """A load reusing a finished request's pending-store block triggers + a flush via update_state_after_alloc's fence. + + num_gpu_blocks=2 forces the BlockPool to give req2 the same block + req1 just freed. + """ + runner = request_runner( + offloaded_block_size=4, + gpu_block_size=4, + num_gpu_blocks=2, + async_scheduling=False, + ) + + runner.new_request(token_ids=[0] * 4) + runner.manager.prepare_store.side_effect = ( + lambda keys, req_context: generate_store_output(keys) + ) + runner.run(decoded_tokens=[EOS_TOKEN_ID], complete_transfers=False) + assert runner.connector_scheduler._block_id_to_pending_jobs + + runner.scheduler.reset_prefix_cache() + runner.new_request(token_ids=[0] * 4) + runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1 + runner.manager.prepare_store.side_effect = ( + lambda keys, req_context: generate_store_output([]) + ) + runner.run( + decoded_tokens=[], + complete_transfers=False, + expected_stored_gpu_block_indexes=(0,), + expected_flushed_gpu_block_indexes=(0,), + ) + assert runner.connector_scheduler._block_id_to_pending_jobs == {} + + +def test_fence_at_build_store_jobs(request_runner): + """A new prefill (no load -> update_state_after_alloc returns early) + reusing a finished request's pending-store block is flushed by + _build_store_jobs's fence.""" + runner = request_runner( + offloaded_block_size=4, + gpu_block_size=4, + num_gpu_blocks=2, + async_scheduling=False, + ) + + runner.new_request(token_ids=[0] * 4) + runner.manager.prepare_store.side_effect = ( + lambda keys, req_context: generate_store_output(keys) + ) + runner.run(decoded_tokens=[EOS_TOKEN_ID], complete_transfers=False) + assert runner.connector_scheduler._block_id_to_pending_jobs + + runner.scheduler.reset_prefix_cache() + runner.new_request(token_ids=[1] * 4) + runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 0 + runner.manager.prepare_store.side_effect = ( + lambda keys, req_context: generate_store_output([]) + ) + runner.run( + decoded_tokens=[EOS_TOKEN_ID], + expected_stored_gpu_block_indexes=(0,), + expected_flushed_gpu_block_indexes=(0,), + ) + assert runner.connector_scheduler._block_id_to_pending_jobs == {} + + +@pytest.mark.parametrize("async_scheduling", [True, False]) +def test_complete_store_called_per_job(request_runner, async_scheduling: bool): + """complete_store fires per-job, not deferred to request finish. + Each call carries only that store's keys.""" + offloaded_block_size = 12 + runner = request_runner( + offloaded_block_size=offloaded_block_size, + gpu_block_size=4, + num_gpu_blocks=100, + async_scheduling=async_scheduling, + ) + runner.new_request(token_ids=[0] * offloaded_block_size) + runner.manager.prepare_store.side_effect = ( + lambda keys, req_context: generate_store_output(keys) + ) + + # First store: fires when block 0 is fully populated. + runner.run(decoded_tokens=[0, 0], expected_stored_gpu_block_indexes=(0, 1, 2)) + assert runner.manager.complete_store.call_count == 1 + first_call_keys = set(runner.manager.complete_store.call_args.args[0]) + assert len(first_call_keys) == 1 + runner.manager.complete_store.reset_mock() + + # Second store: fires when block 1 is fully populated, with different keys. + runner.run( + decoded_tokens=[0] * (offloaded_block_size + 1), + expected_stored_gpu_block_indexes=(3, 4, 5), + ) + assert runner.manager.complete_store.call_count == 1 + second_call_keys = set(runner.manager.complete_store.call_args.args[0]) + assert first_call_keys != second_call_keys + runner.manager.complete_store.reset_mock() + + # Finish: no store pending -> no further call. + runner.run(decoded_tokens=[EOS_TOKEN_ID]) + assert runner.manager.complete_store.call_count == 0 diff --git a/tests/v1/kv_connector/unit/offloading_connector/test_worker_metadata.py b/tests/v1/kv_connector/unit/offloading_connector/test_worker_metadata.py new file mode 100644 index 000000000000..ab9d676cb4ae --- /dev/null +++ b/tests/v1/kv_connector/unit/offloading_connector/test_worker_metadata.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import ( + OffloadingWorkerMetadata, +) + +pytestmark = pytest.mark.cpu_test + + +def test_aggregate_sums_counts(): + meta1 = OffloadingWorkerMetadata(completed_jobs={42: 1, 7: 1}) + meta2 = OffloadingWorkerMetadata(completed_jobs={42: 1, 7: 1}) + result = meta1.aggregate(meta2) + assert result.completed_jobs == {42: 2, 7: 2} + + +def test_aggregate_disjoint_jobs(): + meta1 = OffloadingWorkerMetadata(completed_jobs={42: 1, 7: 1}) + meta2 = OffloadingWorkerMetadata(completed_jobs={43: 1, 8: 1}) + result = meta1.aggregate(meta2) + assert result.completed_jobs == {42: 1, 7: 1, 43: 1, 8: 1} + + +def test_aggregate_multiple_workers(): + meta1 = OffloadingWorkerMetadata(completed_jobs={42: 1, 43: 1, 7: 1}) + meta2 = OffloadingWorkerMetadata(completed_jobs={42: 1, 7: 1, 8: 1}) + meta3 = OffloadingWorkerMetadata(completed_jobs={42: 1, 43: 1, 8: 1}) + result = meta1.aggregate(meta2).aggregate(meta3) + assert result.completed_jobs == {42: 3, 43: 2, 7: 2, 8: 2} diff --git a/tests/v1/kv_connector/unit/offloading_connector/utils.py b/tests/v1/kv_connector/unit/offloading_connector/utils.py index 60dc11f4ca4b..d5adcd3f7724 100644 --- a/tests/v1/kv_connector/unit/offloading_connector/utils.py +++ b/tests/v1/kv_connector/unit/offloading_connector/utils.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import copy from collections.abc import Iterable, Iterator from dataclasses import dataclass from typing import Any @@ -19,6 +18,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import ( OffloadingConnectorMetadata, + OffloadingWorkerMetadata, ) from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import ( OffloadingConnector, @@ -51,7 +51,6 @@ TransferResult, TransferSpec, ) -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput from vllm.v1.request import Request from vllm.v1.structured_output import StructuredOutputManager @@ -369,7 +368,12 @@ def _run(self, decoded_tokens: list[int], complete_transfers: bool): prev_scheduler_output = None prev_model_runner_output = None while True: - assert self.scheduler.requests + # Strict-always-False frees the request immediately on EOS, but + # the worker may still have a deferred store queued. In production + # the next request's step drains it; in single-request tests we + # must keep stepping until the scheduler sees no in-flight jobs. + if not self.scheduler.requests and not self.connector_scheduler._jobs: + break scheduler_output = self.scheduler.schedule() self._update_gpu_block_idx() @@ -392,6 +396,10 @@ def _run(self, decoded_tokens: list[int], complete_transfers: bool): finished_sending, finished_recving = self.worker_connector.get_finished( scheduler_output.finished_req_ids ) + worker_meta = ( + self.worker_connector.build_connector_worker_meta() + or OffloadingWorkerMetadata() + ) self.worker_connector.clear_connector_metadata() @@ -400,6 +408,7 @@ def _run(self, decoded_tokens: list[int], complete_transfers: bool): finished_sending=finished_sending, finished_recving=finished_recving, token_id=token_id or 0, + kv_connector_worker_meta=worker_meta, ) prev_token_id = token_id @@ -420,7 +429,7 @@ def _run(self, decoded_tokens: list[int], complete_transfers: bool): if ( prev_token_id == EOS_TOKEN_ID and prev_token_id != token_id - and self.scheduler.requests + and (self.scheduler.requests or self.connector_scheduler._jobs) ): # continue for one more step to allow offloading to kick off continue @@ -435,26 +444,9 @@ def _run(self, decoded_tokens: list[int], complete_transfers: bool): self._parse_transfers() - # run one more step to update finished stored if EOS_TOKEN_ID in decoded_tokens: assert not self.scheduler.running - while self.scheduler.requests: - scheduler_output = self.scheduler.schedule() - - finished_sending, finished_recving = self.worker_connector.get_finished( - scheduler_output.finished_req_ids - ) - - assert not finished_recving - - model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) - model_runner_output.kv_connector_output = KVConnectorOutput( - finished_sending=finished_sending - ) - - self.scheduler.update_from_output(scheduler_output, model_runner_output) - def run( self, decoded_tokens: list[int], diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 5f0036807b0c..0710ffa63a81 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -24,6 +24,7 @@ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, + KVConnectorWorkerMetadata, ) from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( # noqa ExampleConnector, @@ -249,6 +250,7 @@ def create_model_runner_output( invalid_block_ids: set[int] | None = None, use_eos: bool = False, token_id: int = 0, + kv_connector_worker_meta: KVConnectorWorkerMetadata | None = None, ) -> ModelRunnerOutput: """Make dummy model runner output for testing.""" @@ -266,11 +268,13 @@ def create_model_runner_output( finished_sending is None and finished_recving is None and invalid_block_ids is None + and kv_connector_worker_meta is None ) else KVConnectorOutput( finished_sending=finished_sending, finished_recving=finished_recving, invalid_block_ids=invalid_block_ids or set(), + kv_connector_worker_meta=kv_connector_worker_meta, ) ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/common.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/common.py index 06a727a27b55..c5a251a2a515 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/common.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/common.py @@ -1,15 +1,60 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import dataclass +from dataclasses import dataclass, field -from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata, + KVConnectorWorkerMetadata, +) from vllm.v1.kv_offload.worker.worker import TransferSpec ReqId = str +@dataclass +class TransferJob: + """A transfer job bundling request context with transfer spec. + + Used for both loads and stores, keyed by scheduler-assigned job ID. + The worker reports the job ID back when the transfer finishes, + and the scheduler processes the completion. + """ + + req_id: ReqId + transfer_spec: TransferSpec + + @dataclass class OffloadingConnectorMetadata(KVConnectorMetadata): - reqs_to_load: dict[ReqId, TransferSpec] - reqs_to_store: dict[ReqId, TransferSpec] - reqs_to_flush: set[str] | None = None + # Keyed by scheduler-assigned job IDs. + load_jobs: dict[int, TransferJob] + store_jobs: dict[int, TransferJob] + jobs_to_flush: set[int] | None = None + + +@dataclass +class OffloadingWorkerMetadata(KVConnectorWorkerMetadata): + """Worker -> Scheduler metadata for completed transfer jobs. + + Each worker reports {job_id: 1} for newly completed transfer jobs + (load or store). aggregate() sums counts across workers within a step. + The scheduler accumulates across steps and processes + a transfer completion only when count reaches num_workers. + """ + + completed_jobs: dict[int, int] = field(default_factory=dict) + + def mark_completed(self, job_id: int) -> None: + """Record a transfer job completion from this worker.""" + self.completed_jobs[job_id] = 1 + + def aggregate( + self, other: "KVConnectorWorkerMetadata" + ) -> "KVConnectorWorkerMetadata": + assert isinstance(other, OffloadingWorkerMetadata) + + merged = dict(self.completed_jobs) + for job_id, v in other.completed_jobs.items(): + merged[job_id] = merged.get(job_id, 0) + v + + return OffloadingWorkerMetadata(completed_jobs=merged) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py index 1ef99eaa4461..cb8af41cdd69 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import defaultdict from collections.abc import Iterable, Sequence from dataclasses import dataclass, field from itertools import islice @@ -11,7 +10,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import ( OffloadingConnectorMetadata, + OffloadingWorkerMetadata, ReqId, + TransferJob, ) from vllm.logger import init_logger from vllm.utils.math_utils import cdiv @@ -26,13 +27,27 @@ ) from vllm.v1.kv_offload.mediums import GPULoadStoreSpec from vllm.v1.kv_offload.spec import OffloadingSpec -from vllm.v1.kv_offload.worker.worker import TransferSpec from vllm.v1.outputs import KVConnectorOutput from vllm.v1.request import Request logger = init_logger(__name__) +@dataclass(slots=True) +class TransferJobStatus: + """Tracks scheduler-side state for a single transfer job.""" + + req_id: ReqId + # Number of workers still pending. Starts at num_workers, + # decremented as each worker reports completion. Job is done at 0. + pending_count: int + # Offload keys this job covers; passed to manager.complete_*(). + keys: set[OffloadKey] + is_store: bool + # GPU blocks the fence tracks. Store src blocks; None for loads. + gpu_block_ids: list[int] | None = None + + class GroupOffloadConfig(NamedTuple): group_idx: int gpu_block_size: int @@ -43,10 +58,12 @@ class GroupOffloadConfig(NamedTuple): class SchedulerOffloadConfig(NamedTuple): kv_group_configs: tuple[GroupOffloadConfig, ...] block_size_factor: int + num_workers: int @classmethod def from_spec(cls, spec: OffloadingSpec) -> "SchedulerOffloadConfig": return cls( + num_workers=spec.vllm_config.parallel_config.world_size, kv_group_configs=tuple( GroupOffloadConfig( group_idx=idx, @@ -79,6 +96,9 @@ class RequestOffloadState: req_context: ReqContext = field(init=False) # number of hits in the GPU cache num_locally_computed_tokens: int = 0 + # In-flight job IDs. Per the connector's invariant, at any given time + # this contains either a single load job, or one or more store jobs. + transfer_jobs: set[int] = field(default_factory=set) def __post_init__(self) -> None: self.group_states = tuple( @@ -135,17 +155,26 @@ def __init__(self, spec: OffloadingSpec): self.lookup_groups = attention_groups self._req_status: dict[ReqId, RequestOffloadState] = {} - # requests to load for the current scheduler step - self._reqs_to_load: dict[ReqId, TransferSpec] = {} + self._current_batch_load_jobs: dict[int, TransferJob] = {} + self._current_batch_jobs_to_flush: set[int] = set() # if GPU prefix caching is enabled, # track loaded blocks to avoid redundant loads self._blocks_being_loaded: set[OffloadKey] | None = ( set() if spec.vllm_config.cache_config.enable_prefix_caching else None ) - # request ID -> set(offload keys being stored/loaded) - self._reqs_being_stored = defaultdict[ReqId, set[OffloadKey]](set) - self._reqs_being_loaded = defaultdict[ReqId, set[OffloadKey]](set) + # Job ID counter shared by loads and stores. + self._job_counter: int = 0 + self._jobs: dict[int, TransferJobStatus] = {} + + # block_id -> pending store job_ids. Populated only for finished + # requests (running-request blocks are protected by their ref_cnt). + self._block_id_to_pending_jobs: dict[int, set[int]] = {} + + def _generate_job_id(self) -> int: + job_id = self._job_counter + self._job_counter += 1 + return job_id def _maximal_prefix_lookup( self, keys: Iterable[OffloadKey], req_context: ReqContext @@ -369,23 +398,46 @@ def update_state_after_alloc( # entire KV cache so a remote decode node can consume it. group_state.next_stored_block_idx = num_blocks + # Fence dst blocks against finished-request pending stores. + if ( + self._block_id_to_pending_jobs + and not self._block_id_to_pending_jobs.keys().isdisjoint(dst_block_ids) + ): + self._current_batch_jobs_to_flush.update( + jid + for bid in dst_block_ids + for jid in self._block_id_to_pending_jobs.get(bid, ()) + ) + src_spec = self.manager.prepare_load(keys_to_load, req_status.req_context) dst_spec = GPULoadStoreSpec( dst_block_ids, group_sizes=group_sizes, block_indices=block_indices ) - self._reqs_to_load[request.request_id] = (src_spec, dst_spec) - req_blocks_being_loaded = self._reqs_being_loaded[request.request_id] - req_blocks_being_loaded.update(keys_to_load) + load_job_id = self._generate_job_id() + self._current_batch_load_jobs[load_job_id] = TransferJob( + req_id=request.request_id, + transfer_spec=(src_spec, dst_spec), + ) + # a load can only be issued when no other jobs are pending. + assert not req_status.transfer_jobs + req_status.transfer_jobs.add(load_job_id) + self._jobs[load_job_id] = TransferJobStatus( + req_id=request.request_id, + pending_count=self.config.num_workers, + keys=set(keys_to_load), + is_store=False, + ) if self._blocks_being_loaded is not None: - self._blocks_being_loaded.update(req_blocks_being_loaded) + self._blocks_being_loaded.update(keys_to_load) - def _get_reqs_to_store( - self, scheduler_output: SchedulerOutput - ) -> dict[ReqId, TransferSpec]: + def _build_store_jobs( + self, + scheduler_output: SchedulerOutput, + ) -> dict[int, TransferJob]: block_size_factor = self.config.block_size_factor - reqs_to_store: dict[ReqId, TransferSpec] = {} + store_jobs: dict[int, TransferJob] = {} # iterate over both new and cached requests for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output): req_status = self._req_status[req_id] @@ -398,6 +450,19 @@ def _get_reqs_to_store( if new_block_id_groups: req_status.update_block_id_groups(new_block_id_groups) + # Fence new blocks against in-flight stores. + if self._block_id_to_pending_jobs: + new_blocks_flat = [ + bid for new_blocks in new_block_id_groups for bid in new_blocks + ] + if not self._block_id_to_pending_jobs.keys().isdisjoint( + new_blocks_flat + ): + self._current_batch_jobs_to_flush.update( + jid + for bid in new_blocks_flat + for jid in self._block_id_to_pending_jobs.get(bid, ()) + ) num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] num_tokens_after_batch = req.num_computed_tokens + num_scheduled_tokens @@ -491,36 +556,52 @@ def _get_reqs_to_store( ) dst_spec = store_output.store_spec - reqs_to_store[req_id] = (src_spec, dst_spec) - self._reqs_being_stored[req_id] |= keys_to_store + job_id = self._generate_job_id() + # a store can only be issued when no load is pending. + if req_status.transfer_jobs: + any_jid = next(iter(req_status.transfer_jobs)) + assert self._jobs[any_jid].is_store + req_status.transfer_jobs.add(job_id) + self._jobs[job_id] = TransferJobStatus( + req_id=req_id, + pending_count=self.config.num_workers, + keys=set(keys_to_store), + is_store=True, + gpu_block_ids=src_block_ids, + ) + + store_jobs[job_id] = TransferJob( + req_id=req_id, transfer_spec=(src_spec, dst_spec) + ) logger.debug( - "Request %s offloading %s blocks upto %d tokens", + "Request %s offloading %s blocks upto %d tokens (job %d)", req_id, len(keys_to_store), num_offloadable_tokens, + job_id, ) - return reqs_to_store + return store_jobs def build_connector_meta( self, scheduler_output: SchedulerOutput ) -> KVConnectorMetadata: - meta = OffloadingConnectorMetadata( - reqs_to_load=self._reqs_to_load, - reqs_to_store=self._get_reqs_to_store(scheduler_output), - reqs_to_flush=scheduler_output.preempted_req_ids, - ) - self._reqs_to_load = {} - - # NOTE (orozery): we should move this logic to update_connector_output - # once KVConnectorOutput allows us to report completed transfers for req_id in scheduler_output.preempted_req_ids or (): - keys = self._reqs_being_stored.get(req_id) - if keys: - self.manager.complete_store(keys) - keys.clear() + req_status = self._req_status.get(req_id) + if req_status is None or not req_status.transfer_jobs: + continue + any_jid = next(iter(req_status.transfer_jobs)) + assert self._jobs[any_jid].is_store + self._current_batch_jobs_to_flush.update(req_status.transfer_jobs) + meta = OffloadingConnectorMetadata( + load_jobs=self._current_batch_load_jobs, + store_jobs=self._build_store_jobs(scheduler_output), + jobs_to_flush=self._current_batch_jobs_to_flush, + ) + self._current_batch_load_jobs = {} + self._current_batch_jobs_to_flush = set() return meta def update_connector_output(self, connector_output: KVConnectorOutput): @@ -531,17 +612,37 @@ def update_connector_output(self, connector_output: KVConnectorOutput): connector_output (KVConnectorOutput): the worker-side connectors output. """ - for req_id in connector_output.finished_sending or []: - keys = self._reqs_being_stored.pop(req_id, None) - if keys: - self.manager.complete_store(keys) - - for req_id in connector_output.finished_recving or []: - keys = self._reqs_being_loaded.pop(req_id, None) - if keys: + meta = connector_output.kv_connector_worker_meta + if not isinstance(meta, OffloadingWorkerMetadata): + assert meta is None + meta = OffloadingWorkerMetadata() + for job_id, count in meta.completed_jobs.items(): + assert count > 0 + job_status = self._jobs[job_id] + job_status.pending_count -= count + if job_status.pending_count > 0: + continue + assert job_status.pending_count == 0 + + if job_status.is_store: + self.manager.complete_store(job_status.keys) + else: + self.manager.complete_load(job_status.keys) if self._blocks_being_loaded: - self._blocks_being_loaded.difference_update(keys) - self.manager.complete_load(keys) + self._blocks_being_loaded.difference_update(job_status.keys) + + req_status = self._req_status[job_status.req_id] + if self._block_id_to_pending_jobs and req_status.req.is_finished(): + for bid in job_status.gpu_block_ids or (): + pending = self._block_id_to_pending_jobs[bid] + pending.remove(job_id) + if not pending: + del self._block_id_to_pending_jobs[bid] + + del self._jobs[job_id] + req_status.transfer_jobs.remove(job_id) + if not req_status.transfer_jobs and req_status.req.is_finished(): + del self._req_status[job_status.req_id] def request_finished( self, @@ -558,14 +659,21 @@ def request_finished( Optional KVTransferParams to be included in the request outputs returned by the engine. """ - req_id = request.request_id - # TODO(orozery): possibly kickoff offload for last block # which may have been deferred due to async scheduling - self._req_status.pop(req_id, None) - - request_being_stored = req_id in self._reqs_being_stored - return request_being_stored, None + req_status = self._req_status.get(request.request_id) + if req_status is None: + return False, None + if not req_status.transfer_jobs: + del self._req_status[request.request_id] + return False, None + # Pending stores will outlive the request's block ownership. + # Register them so future block reuse triggers a flush. + for job_id in req_status.transfer_jobs: + job_status = self._jobs[job_id] + for bid in job_status.gpu_block_ids or (): + self._block_id_to_pending_jobs.setdefault(bid, set()).add(job_id) + return False, None def take_events(self) -> Iterable[KVCacheEvent]: """Take the KV cache events from the connector. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py index cc6d8262c7e6..78547d569df3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py @@ -11,6 +11,7 @@ ) from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import ( OffloadingConnectorMetadata, + OffloadingWorkerMetadata, ReqId, ) from vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics import ( @@ -45,24 +46,11 @@ def __init__(self, spec: OffloadingSpec): self.spec = spec self.worker = OffloadingWorker() - self._job_counter = 0 - self.kv_connector_stats = OffloadingConnectorStats() - # req_id -> (job_id, store) - self._jobs: dict[int, tuple[ReqId, bool]] = {} - # req_id -> active job IDs - self._load_job: dict[ReqId, int] = {} - # req_id -> set(active job IDs) - self._store_jobs = defaultdict[ReqId, set[int]](set) - # list of store jobs pending submission (job_id, transfer_spec) + # job_id -> req_id for in-flight loads. + self._load_jobs: dict[int, ReqId] = {} self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = [] - - self._finished_reqs_waiting_for_store: set[ReqId] = set() - - def _generate_job_id(self) -> int: - job_id = self._job_counter - self._job_counter = job_id + 1 - return job_id + self._connector_worker_meta = OffloadingWorkerMetadata() def _register_handlers(self, kv_caches: CanonicalKVCaches): for src_cls, dst_cls, handler in self.spec.get_handlers(kv_caches): @@ -301,10 +289,8 @@ def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata) assert success self._unsubmitted_store_jobs.clear() - for req_id in kv_connector_metadata.reqs_to_flush or (): - job_ids = self._store_jobs.get(req_id) - if job_ids: - self.worker.wait(job_ids) + if kv_connector_metadata.jobs_to_flush: + self.worker.wait(kv_connector_metadata.jobs_to_flush) def start_kv_transfers(self, metadata: OffloadingConnectorMetadata): for job_id, transfer_spec in self._unsubmitted_store_jobs: @@ -312,41 +298,33 @@ def start_kv_transfers(self, metadata: OffloadingConnectorMetadata): assert success self._unsubmitted_store_jobs.clear() - for req_id, transfer_spec in metadata.reqs_to_load.items(): - job_id = self._generate_job_id() - self._jobs[job_id] = (req_id, False) - assert req_id not in self._load_job - self._load_job[req_id] = job_id - success = self.worker.transfer_async(job_id, transfer_spec) + for job_id, entry in metadata.load_jobs.items(): + self._load_jobs[job_id] = entry.req_id + success = self.worker.transfer_async(job_id, entry.transfer_spec) assert success def prepare_store_kv(self, metadata: OffloadingConnectorMetadata): - for req_id, transfer_spec in metadata.reqs_to_store.items(): - job_id = self._generate_job_id() - self._jobs[job_id] = (req_id, True) - self._store_jobs[req_id].add(job_id) - # NOTE(orozery): defer the store to the beginning of the next engine step, - # so that offloading starts AFTER transfers related to token sampling, - # thereby avoiding delays to token generation due to offloading. - self._unsubmitted_store_jobs.append((job_id, transfer_spec)) + for job_id, entry in metadata.store_jobs.items(): + # NOTE(orozery): defer the store to the beginning of the next + # engine step, so that offloading starts AFTER transfers related + # to token sampling, thereby avoiding delays to token generation. + self._unsubmitted_store_jobs.append((job_id, entry.transfer_spec)) def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: """ - Notifies worker-side connector ids of requests that have - finished generating tokens. - Returns a list of request IDs that finished loading or storing. - Returns: - ids of requests that have finished asynchronous transfer - tuple of (sending/saving ids, recving/loading ids). + tuple of (finished_sending, finished_recving). Stores never + emit finished_sending — the scheduler tracks store completion + via kv_connector_worker_meta.completed_jobs and fences any + block reuse via jobs_to_flush. Loads still emit + finished_recving so the base scheduler can resume requests + blocked on remote KV (and free aborted-during-load reqs). """ - finished_sending = set() - finished_recving = set() + finished_recving: set[str] = set() for transfer_result in self.worker.get_finished(): # we currently do not support job failures job_id = transfer_result.job_id assert transfer_result.success - req_id, store = self._jobs.pop(job_id) if ( transfer_result.transfer_time and transfer_result.transfer_size is not None @@ -357,31 +335,21 @@ def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: time=transfer_result.transfer_time, transfer_type=transfer_result.transfer_type, ) - if store: - req_jobs = self._store_jobs[req_id] - req_jobs.remove(job_id) - if req_jobs: - continue - - if req_id in self._finished_reqs_waiting_for_store: - self._finished_reqs_waiting_for_store.remove(req_id) - finished_sending.add(req_id) - del self._store_jobs[req_id] - else: - req_job = self._load_job[req_id] - assert job_id == req_job - del self._load_job[req_id] + + self._connector_worker_meta.mark_completed(job_id) + req_id = self._load_jobs.pop(job_id, None) + if req_id is not None: finished_recving.add(req_id) - for req_id in finished_req_ids: - pending_req_jobs = self._store_jobs.get(req_id) - if pending_req_jobs: - self._finished_reqs_waiting_for_store.add(req_id) - elif pending_req_jobs is not None: - finished_sending.add(req_id) - del self._store_jobs[req_id] + return set(), finished_recving - return finished_sending, finished_recving + def build_connector_worker_meta(self) -> OffloadingWorkerMetadata | None: + """Return completed transfer job IDs since the last call.""" + if not self._connector_worker_meta.completed_jobs: + return None + meta = self._connector_worker_meta + self._connector_worker_meta = OffloadingWorkerMetadata() + return meta def get_kv_connector_stats(self) -> KVConnectorStats | None: """ @@ -396,11 +364,7 @@ def get_kv_connector_stats(self) -> KVConnectorStats | None: return kv_connector_stats def shutdown(self) -> None: - # Drop deferred store jobs: there is no point in submitting - # them during shutdown. self._unsubmitted_store_jobs.clear() - self._jobs.clear() - self._load_job.clear() - self._store_jobs.clear() - self._finished_reqs_waiting_for_store.clear() + self._load_jobs.clear() + self._connector_worker_meta = OffloadingWorkerMetadata() self.worker.shutdown() diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py index f11281dcf14e..05b835572c9f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -20,6 +20,7 @@ ) from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import ( OffloadingConnectorMetadata, + OffloadingWorkerMetadata, ) from vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics import ( OffloadingConnectorStats, @@ -111,6 +112,11 @@ def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: assert self.connector_worker is not None return self.connector_worker.get_finished(finished_req_ids) + def build_connector_worker_meta(self) -> OffloadingWorkerMetadata | None: + if self.connector_worker is not None: + return self.connector_worker.build_connector_worker_meta() + return None + def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int ) -> tuple[int | None, bool]: