From 3d6d95350947c9eca4a208f5d33f0a67625cbde1 Mon Sep 17 00:00:00 2001 From: Martin Hickey Date: Tue, 21 Apr 2026 18:38:02 +0100 Subject: [PATCH 1/4] Refactor KV Offloading Signed-off-by: Martin Hickey --- vllm/v1/kv_offload/base.py | 12 +++ vllm/v1/kv_offload/cpu/manager.py | 104 ++++++++++++++++++++++++ vllm/v1/kv_offload/cpu/spec.py | 1 + vllm/v1/kv_offload/reuse_manager.py | 118 ---------------------------- 4 files changed, 117 insertions(+), 118 deletions(-) delete mode 100644 vllm/v1/kv_offload/reuse_manager.py diff --git a/vllm/v1/kv_offload/base.py b/vllm/v1/kv_offload/base.py index b30c5d066acc..b7320cf2bf09 100644 --- a/vllm/v1/kv_offload/base.py +++ b/vllm/v1/kv_offload/base.py @@ -128,7 +128,11 @@ def lookup(self, key: OffloadKey, req_context: ReqContext) -> bool | None: @abstractmethod def prepare_load( self, +<<<<<<< HEAD keys: Sequence[OffloadKey], +======= + keys: Iterable[OffloadKey], +>>>>>>> 83731c08b (Refactor KV Offloading) req_context: ReqContext, ) -> LoadStoreSpec: """ @@ -147,7 +151,11 @@ def prepare_load( """ pass +<<<<<<< HEAD def touch(self, keys: Sequence[OffloadKey]): +======= + def touch(self, keys: Iterable[OffloadKey]): +>>>>>>> 83731c08b (Refactor KV Offloading) """ Mark the given blocks as recently used. This could in practice mean moving them to the end of an LRU list. 
@@ -169,7 +177,11 @@ def complete_load(self, keys: Iterable[OffloadKey]): @abstractmethod def prepare_store( self, +<<<<<<< HEAD keys: Sequence[OffloadKey], +======= + keys: Iterable[OffloadKey], +>>>>>>> 83731c08b (Refactor KV Offloading) req_context: ReqContext, ) -> PrepareStoreOutput | None: """ diff --git a/vllm/v1/kv_offload/cpu/manager.py b/vllm/v1/kv_offload/cpu/manager.py index 3527773a2004..76613444ad33 100644 --- a/vllm/v1/kv_offload/cpu/manager.py +++ b/vllm/v1/kv_offload/cpu/manager.py @@ -196,3 +196,107 @@ def take_events(self) -> Iterable[OffloadingEvent]: if self.events is not None: yield from self.events self.events.clear() + + +# ----------------------------------------------------------------------------- +# FilterReusedOffloadingManager — reuse-frequency gating for CPU offload stores +# ----------------------------------------------------------------------------- + + +class FilterReusedOffloadingManager(OffloadingManager): + """An :class:`OffloadingManager` decorator that skips storing blocks + whose reuse frequency is below *store_threshold*. + + All methods are delegated to the *backing* manager. Two methods are + intercepted: + + * ``prepare_store`` — filters out keys that have not yet + * ``lookup`` — records the visited key in an internal LRU + counter, then delegates to the backing manager. + crossed the threshold *before* calling the backing + ``prepare_store``. + + Args: + backing: The underlying ``OffloadingManager`` to delegate to. + store_threshold: A block must be seen at least this many times in + ``lookup()`` before it is eligible for offloading. Must be >= 2 + (a value of 1 would be equivalent to no filtering). + max_tracker_size: Maximum entries in the internal tracker's LRU table. 
+ """ + + def __init__( + self, + backing: OffloadingManager, + store_threshold: int = 2, + max_tracker_size: int = 64_000, + ): + if store_threshold < 2: + raise ValueError( + "FilterReusedOffloadingManager store_threshold must be >= 2, " + f"got {store_threshold}" + ) + if max_tracker_size < 1: + raise ValueError( + "FilterReusedOffloadingManager max_tracker_size must be >= 1, " + f"got {max_tracker_size}" + ) + self._backing = backing + self.store_threshold = store_threshold + self.max_tracker_size = max_tracker_size + # Ordered so we can evict the LRU entry in O(1). + self.counts: OrderedDict[OffloadKey, int] = OrderedDict() + + # ------------------------------------------------------------------ + # Intercepted methods + # ------------------------------------------------------------------ + + def lookup(self, key: OffloadKey, req_context: ReqContext) -> bool | None: + """Record the key, then delegate lookup to backing manager.""" + if key in self.counts: + self.counts.move_to_end(key) + self.counts[key] += 1 + else: + if len(self.counts) >= self.max_tracker_size: + self.counts.popitem(last=False) # evict LRU + self.counts[key] = 1 + return self._backing.lookup(key, req_context) + + def prepare_store( + self, keys: Iterable[OffloadKey], req_context: ReqContext + ) -> PrepareStoreOutput | None: + """Filter out blocks below threshold, then delegate to backing. + + Filtering is evaluated *before* calling the backing manager's + ``prepare_store`` so that blocks that would be skipped do not + consume any CPU offload capacity. + """ + keys = list(keys) + eligible = [ + key for key in keys if self.counts.get(key, 0) >= self.store_threshold + ] + + # Passing an empty list is intentional and safe — CPUOffloadingManager + # handles it correctly, returning a PrepareStoreOutput with empty lists. + # Delegate to the backing manager with only the eligible keys. 
+ return self._backing.prepare_store(eligible, req_context) + + # ------------------------------------------------------------------ + # Delegated methods + # ------------------------------------------------------------------ + + def prepare_load( + self, keys: Iterable[OffloadKey], req_context: ReqContext + ) -> LoadStoreSpec: + return self._backing.prepare_load(keys, req_context) + + def touch(self, keys: Iterable[OffloadKey]) -> None: + return self._backing.touch(keys) + + def complete_load(self, keys: Iterable[OffloadKey]) -> None: + return self._backing.complete_load(keys) + + def complete_store(self, keys: Iterable[OffloadKey], success: bool = True) -> None: + return self._backing.complete_store(keys, success) + + def take_events(self) -> Iterable[OffloadingEvent]: + return self._backing.take_events() diff --git a/vllm/v1/kv_offload/cpu/spec.py b/vllm/v1/kv_offload/cpu/spec.py index 54046d98f452..2fa493d429a9 100644 --- a/vllm/v1/kv_offload/cpu/spec.py +++ b/vllm/v1/kv_offload/cpu/spec.py @@ -6,6 +6,7 @@ from vllm.platforms import current_platform from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_offload.base import ( + BlockIDsLoadStoreSpec, CanonicalKVCaches, GPULoadStoreSpec, LoadStoreSpec, diff --git a/vllm/v1/kv_offload/reuse_manager.py b/vllm/v1/kv_offload/reuse_manager.py deleted file mode 100644 index 2e556ca8d054..000000000000 --- a/vllm/v1/kv_offload/reuse_manager.py +++ /dev/null @@ -1,118 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Reuse-frequency gating for CPU KV-cache offload stores. - -FilterReusedOffloadingManager — OffloadingManager decorator that skips - storing blocks that have not yet been seen enough times. 
-""" - -from collections import OrderedDict -from collections.abc import Iterable, Sequence - -from vllm.v1.kv_offload.base import ( - LoadStoreSpec, - OffloadingEvent, - OffloadingManager, - OffloadKey, - PrepareStoreOutput, - ReqContext, -) - - -class FilterReusedOffloadingManager(OffloadingManager): - """An :class:`OffloadingManager` decorator that skips storing blocks - whose reuse frequency is below *store_threshold*. - - All methods are delegated to the *backing* manager. Two methods are - intercepted: - - * ``prepare_store`` — filters out keys that have not yet - * ``lookup`` — records the visited key in an internal LRU - counter, then delegates to the backing manager. - crossed the threshold *before* calling the backing - ``prepare_store``. - - Args: - backing: The underlying ``OffloadingManager`` to delegate to. - store_threshold: A block must be seen at least this many times in - ``lookup()`` before it is eligible for offloading. Must be >= 2 - (a value of 1 would be equivalent to no filtering). - max_tracker_size: Maximum entries in the internal tracker's LRU table. - """ - - def __init__( - self, - backing: OffloadingManager, - store_threshold: int = 2, - max_tracker_size: int = 64_000, - ): - if store_threshold < 2: - raise ValueError( - "FilterReusedOffloadingManager store_threshold must be >= 2, " - f"got {store_threshold}" - ) - if max_tracker_size < 1: - raise ValueError( - "FilterReusedOffloadingManager max_tracker_size must be >= 1, " - f"got {max_tracker_size}" - ) - self._backing = backing - self.store_threshold = store_threshold - self.max_tracker_size = max_tracker_size - # Ordered so we can evict the LRU entry in O(1). 
- self.counts: OrderedDict[OffloadKey, int] = OrderedDict() - - # ------------------------------------------------------------------ - # Intercepted methods - # ------------------------------------------------------------------ - - def lookup(self, key: OffloadKey, req_context: ReqContext) -> bool | None: - """Record the key, then delegate lookup to backing manager.""" - if key in self.counts: - self.counts.move_to_end(key) - self.counts[key] += 1 - else: - if len(self.counts) >= self.max_tracker_size: - self.counts.popitem(last=False) # evict LRU - self.counts[key] = 1 - return self._backing.lookup(key, req_context) - - def prepare_store( - self, keys: Sequence[OffloadKey], req_context: ReqContext - ) -> PrepareStoreOutput | None: - """Filter out blocks below threshold, then delegate to backing. - - Filtering is evaluated *before* calling the backing manager's - ``prepare_store`` so that blocks that would be skipped do not - consume any CPU offload capacity. - """ - eligible = [ - key for key in keys if self.counts.get(key, 0) >= self.store_threshold - ] - - # Passing an empty list is intentional and safe — CPUOffloadingManager - # handles it correctly, returning a PrepareStoreOutput with empty lists. - # Delegate to the backing manager with only the eligible keys. 
- return self._backing.prepare_store(eligible, req_context) - - # ------------------------------------------------------------------ - # Delegated methods - # ------------------------------------------------------------------ - - def prepare_load( - self, keys: Sequence[OffloadKey], req_context: ReqContext - ) -> LoadStoreSpec: - return self._backing.prepare_load(keys, req_context) - - def touch(self, keys: Sequence[OffloadKey]) -> None: - return self._backing.touch(keys) - - def complete_load(self, keys: Iterable[OffloadKey]) -> None: - return self._backing.complete_load(keys) - - def complete_store(self, keys: Iterable[OffloadKey], success: bool = True) -> None: - return self._backing.complete_store(keys, success) - - def take_events(self) -> Iterable[OffloadingEvent]: - return self._backing.take_events() From 3133c689098aeacaeb6f2df7c160a363022793e8 Mon Sep 17 00:00:00 2001 From: Martin Hickey Date: Wed, 22 Apr 2026 12:01:28 +0100 Subject: [PATCH 2/4] Fix fragile circular dependency Code review from gemini-code-assist to avoid potential circular dependency with `vllm.v1.kv_offload.cpu.manager` to add the manager imports to the relevant methods. 
https://github.com/vllm-project/vllm/pull/40538#discussion_r3119357634 Signed-off-by: Martin Hickey --- vllm/v1/kv_offload/cpu/spec.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/v1/kv_offload/cpu/spec.py b/vllm/v1/kv_offload/cpu/spec.py index 2fa493d429a9..17b77da8d502 100644 --- a/vllm/v1/kv_offload/cpu/spec.py +++ b/vllm/v1/kv_offload/cpu/spec.py @@ -57,6 +57,11 @@ def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig): def get_manager(self) -> OffloadingManager: if not self._manager: + from vllm.v1.kv_offload.cpu.manager import ( + CPUOffloadingManager, + FilterReusedOffloadingManager, + ) + kv_events_config = self.vllm_config.kv_events_config enable_events = ( kv_events_config is not None and kv_events_config.enable_kv_cache_events From fec97b67a99c9f10bda6f646008edfee86b46407 Mon Sep 17 00:00:00 2001 From: Martin Hickey Date: Wed, 22 Apr 2026 15:41:49 +0100 Subject: [PATCH 3/4] [kv_offload] Decouple store policy and request lifecycle from the scheduler This commit helps separate three concerns that were previously conflated in `OffloadingConnectorScheduler`: transfer lifecycle, store policy, and request teardown. It achieves this by: - Adding a request_finished lifecycle hook to OffloadingManager so the scheduler can ask the manager whether GPU blocks are safe to free, rather than maintaining that knowledge itself via an inline dict check. - Extracting the hardcoded store-on-compute logic into a pluggable interface `OffloadPolicy`. `StoreOnComputePolicy` demonstrates the ability to add future policies (preemption-only, spillover) which can be injected at construction with no scheduler changes. - Moving the per-request store watermark out of the general-purpose `RequestKVState` struct and into the policy that owns it.
Signed-off-by: Martin Hickey --- .../kv_connector/v1/offloading/policy.py | 195 ++++++++++++++++++ .../kv_connector/v1/offloading/scheduler.py | 46 +++-- vllm/v1/kv_offload/base.py | 12 ++ vllm/v1/kv_offload/cpu/manager.py | 6 + 4 files changed, 246 insertions(+), 13 deletions(-) create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/offloading/policy.py diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/policy.py new file mode 100644 index 000000000000..7ed3926650e9 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/policy.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data +from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import ReqId +from vllm.logger import init_logger +from vllm.v1.kv_offload.base import GPULoadStoreSpec, OffloadingManager, OffloadKey +from vllm.v1.kv_offload.worker.worker import TransferSpec + +if TYPE_CHECKING: + from vllm.distributed.kv_transfer.kv_connector.v1.offloading.scheduler import ( + RequestKVState, + SchedulerOffloadConfig, + ) + from vllm.v1.core.sched.output import SchedulerOutput + +logger = init_logger(__name__) + + +class OffloadPolicy(ABC): + """ + Decides which KV cache blocks to offload each scheduler step. + + Implementations may store per-request state but must clean it up + via ``request_finished``. 
+ """ + + @abstractmethod + def get_blocks_to_store( + self, + req_kv_states: dict[str, RequestKVState], + scheduler_output: SchedulerOutput, + config: SchedulerOffloadConfig, + manager: OffloadingManager, + reqs_being_stored: dict[ReqId, set[OffloadKey]], + ) -> dict[ReqId, TransferSpec]: + """ + Decide which blocks to store this scheduler step. + + Args: + req_kv_states: per-request KV tracking state. + scheduler_output: the current scheduler output. + config: offloading configuration. + manager: the offloading manager to call prepare_store on. + reqs_being_stored: scheduler-owned dict of in-flight store keys, + updated in-place for each request that gets a store queued. + + Returns: + A dict mapping request ID to the TransferSpec to submit. + """ + pass + + @abstractmethod + def request_finished(self, req_id: str) -> None: + """Clean up per-request policy state on request completion.""" + pass + + @abstractmethod + def notify_load_scheduled( + self, req_id: str, next_block_idx_per_group: list[int] + ) -> None: + """ + Advance the store watermark when blocks are scheduled for loading, + preventing the policy from re-storing blocks already being loaded. + + Args: + req_id: the request whose watermark to advance. + next_block_idx_per_group: per-group block count up to which a + load has been scheduled. + """ + pass + + +class StoreOnComputePolicy(OffloadPolicy): + """ + Store blocks immediately as they are computed. + + This is the default policy: each scheduler step it identifies newly + computed full offload-blocks and queues them for transfer to the + offload medium. 
+ """ + + def __init__(self) -> None: + # req_id -> per-group index of the next block that needs to be stored + self._next_stored_block_idx: dict[str, list[int]] = {} + + def get_blocks_to_store( + self, + req_kv_states: dict[str, RequestKVState], + scheduler_output: SchedulerOutput, + config: SchedulerOffloadConfig, + manager: OffloadingManager, + reqs_being_stored: dict[ReqId, set[OffloadKey]], + ) -> dict[ReqId, TransferSpec]: + # Below assertion will be removed once this function supports HMA + assert len(config.kv_group_configs) == 1 + group_config = config.kv_group_configs[0] + + reqs_to_store: dict[ReqId, TransferSpec] = {} + for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output): + req_kv_state = req_kv_states[req_id] + req_kv_state.update_offload_keys() + + if preempted: + for group_state in req_kv_state.group_states: + group_state.block_ids.clear() + + if new_block_id_groups: + req_kv_state.update_block_id_groups(new_block_id_groups) + + # Below assertion will be removed once this function supports HMA + assert len(req_kv_state.group_states) == 1 + group_state = req_kv_state.group_states[0] + + block_ids = group_state.block_ids + + req = req_kv_state.req + new_tokens = scheduler_output.num_scheduled_tokens[req_id] + expected_tokens = req.num_computed_tokens + new_tokens + total_tokens = min(expected_tokens, req.num_tokens) + num_blocks = total_tokens // group_config.offloaded_block_size + + if req_id not in self._next_stored_block_idx: + self._next_stored_block_idx[req_id] = [0] * len( + req_kv_state.group_states + ) + start_block_idx = self._next_stored_block_idx[req_id][0] + num_new_blocks = num_blocks - start_block_idx + + if num_new_blocks <= 0: + continue + + num_gpu_blocks = num_blocks * config.block_size_factor + assert len(req.block_hashes) >= num_gpu_blocks + + new_offload_keys = group_state.offload_keys[start_block_idx:num_blocks] + store_output = manager.prepare_store( + new_offload_keys, req_kv_state.req_context + ) + 
if store_output is None: + logger.warning( + "Request %s: cannot store %s blocks", req_id, num_new_blocks + ) + continue + + self._next_stored_block_idx[req_id][0] = num_blocks + + if not store_output.keys_to_store: + continue + keys_to_store = set(store_output.keys_to_store) + + manager.touch(group_state.offload_keys[:num_blocks]) + + dst_spec = store_output.store_spec + src_block_ids: list[int] = [] + for idx, key in enumerate(new_offload_keys): + if key not in keys_to_store: + continue + offloaded_block_idx = start_block_idx + idx + gpu_block_idx = offloaded_block_idx * config.block_size_factor + for i in range(config.block_size_factor): + src_block_ids.append(block_ids[gpu_block_idx + i]) + src_spec = GPULoadStoreSpec( + src_block_ids, + group_sizes=(len(src_block_ids),), + block_indices=(0,), + ) + + reqs_to_store[req_id] = (src_spec, dst_spec) + reqs_being_stored[req_id] |= keys_to_store + + logger.debug( + "Request %s offloading %s blocks starting from block #%d", + req_id, + len(keys_to_store), + start_block_idx, + ) + + return reqs_to_store + + def request_finished(self, req_id: str) -> None: + self._next_stored_block_idx.pop(req_id, None) + + def notify_load_scheduled( + self, req_id: str, next_block_idx_per_group: list[int] + ) -> None: + state = self._next_stored_block_idx.setdefault( + req_id, [0] * len(next_block_idx_per_group) + ) + for i, val in enumerate(next_block_idx_per_group): + state[i] = max(state[i], val) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py index f437782070df..ec9995055173 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py @@ -6,7 +6,6 @@ from typing import Any, NamedTuple from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent -from vllm.distributed.kv_transfer.kv_connector.utils import 
yield_req_data from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import ( OffloadingConnectorMetadata, @@ -14,6 +13,10 @@ ReqId, TransferJob, ) +from vllm.distributed.kv_transfer.kv_connector.v1.offloading.policy import ( + OffloadPolicy, + StoreOnComputePolicy, +) from vllm.logger import init_logger from vllm.utils.math_utils import cdiv from vllm.v1.core.kv_cache_manager import KVCacheBlocks @@ -84,12 +87,10 @@ def from_spec(cls, spec: OffloadingSpec) -> "SchedulerOffloadConfig": class RequestGroupState: offload_keys: list[OffloadKey] = field(default_factory=list) block_ids: list[int] = field(default_factory=list) - # index of next block (of size offloaded_block_size) to offload - next_stored_block_idx: int = 0 @dataclass(slots=True) -class RequestOffloadState: +class RequestKVState: config: SchedulerOffloadConfig req: Request group_states: tuple[RequestGroupState, ...] = field(init=False) @@ -143,7 +144,7 @@ def advance_stored_idx(self, num_offloadable_tokens: int) -> None: class OffloadingConnectorScheduler: """Implementation of Scheduler side methods""" - def __init__(self, spec: OffloadingSpec): + def __init__(self, spec: OffloadingSpec, policy: OffloadPolicy | None = None): self.config = SchedulerOffloadConfig.from_spec(spec) self.manager: OffloadingManager = spec.get_manager() @@ -154,7 +155,8 @@ def __init__(self, spec: OffloadingSpec): self.lookup_groups = attention_groups - self._req_status: dict[ReqId, RequestOffloadState] = {} + self._policy: OffloadPolicy = policy or StoreOnComputePolicy() + self._req_kv_states: dict[ReqId, RequestKVState] = {} self._current_batch_load_jobs: dict[int, TransferJob] = {} self._current_batch_jobs_to_flush: set[int] = set() # if GPU prefix caching is enabled, @@ -241,13 +243,13 @@ def get_num_new_matched_tokens( - `True` if tokens will be loaded asynchronously (between scheduler steps). 
""" - if req_status := self._req_status.get(request.request_id): + if req_status := self._req_kv_states.get(request.request_id): # make sure block IDs are cleared for group_state in req_status.group_states: group_state.block_ids.clear() else: - req_status = RequestOffloadState(config=self.config, req=request) - self._req_status[request.request_id] = req_status + req_status = RequestKVState(config=self.config, req=request) + self._req_kv_states[request.request_id] = req_status req_status.update_offload_keys() req_status.num_locally_computed_tokens = num_computed_tokens @@ -338,8 +340,8 @@ def update_state_after_alloc( if num_external_tokens == 0: return - req_status = self._req_status[request.request_id] - + req_status = self._req_kv_states[request.request_id] + num_locally_computed_tokens = req_status.num_locally_computed_tokens num_cached_tokens = num_locally_computed_tokens + num_external_tokens @@ -428,6 +430,8 @@ def update_state_after_alloc( keys=set(keys_to_load), is_store=False, ) + # MH TODO: what happen here now? 
+ # self._policy.notify_load_scheduled(request.request_id, [num_blocks]) if self._blocks_being_loaded is not None: self._blocks_being_loaded.update(keys_to_load) @@ -583,6 +587,16 @@ def _build_store_jobs( ) return store_jobs + + # MH TODO: What to do with this now as _build_store_jobs() replaced this func + """def _get_reqs_to_store(self, scheduler_output: SchedulerOutput): + return self._policy.get_blocks_to_store( + self._req_kv_states, + scheduler_output, + self.config, + self.manager, + self._reqs_being_stored, + )""" def build_connector_meta( self, scheduler_output: SchedulerOutput @@ -661,11 +675,11 @@ def request_finished( """ # TODO(orozery): possibly kickoff offload for last block # which may have been deferred due to async scheduling - req_status = self._req_status.get(request.request_id) + req_status = self._req_kv_states.get(request.request_id) if req_status is None: return False, None if not req_status.transfer_jobs: - del self._req_status[request.request_id] + del self._req_kv_states[request.request_id] return False, None # Pending stores will outlive the request's block ownership. # Register them so future block reuse triggers a flush. @@ -675,6 +689,12 @@ def request_finished( self._block_id_to_pending_jobs.setdefault(bid, set()).add(job_id) return False, None + # MH TODO: How to merge this or how does it work now? + """self._req_kv_states.pop(req_id, None) + self._policy.request_finished(req_id) + + return self.manager.request_finished(req_id), None""" + def take_events(self) -> Iterable[KVCacheEvent]: """Take the KV cache events from the connector. diff --git a/vllm/v1/kv_offload/base.py b/vllm/v1/kv_offload/base.py index b7320cf2bf09..c7609a2992dc 100644 --- a/vllm/v1/kv_offload/base.py +++ b/vllm/v1/kv_offload/base.py @@ -223,6 +223,18 @@ def take_events(self) -> Iterable[OffloadingEvent]: """ return () + def request_finished(self, req_id: str) -> bool: + """ + Called when a request has finished, before its GPU blocks are freed. 
+ + Returns: + True if the manager is still performing async work for this + request (e.g. an in-flight store) and GPU blocks must not yet + be freed. The scheduler will wait until the transfer completes + before releasing blocks. + """ + return False + def shutdown(self) -> None: """Shutdown the manager and release any resources.""" return diff --git a/vllm/v1/kv_offload/cpu/manager.py b/vllm/v1/kv_offload/cpu/manager.py index 76613444ad33..770981e6dbec 100644 --- a/vllm/v1/kv_offload/cpu/manager.py +++ b/vllm/v1/kv_offload/cpu/manager.py @@ -197,6 +197,9 @@ def take_events(self) -> Iterable[OffloadingEvent]: yield from self.events self.events.clear() + def request_finished(self, req_id: str) -> bool: + return False + # ----------------------------------------------------------------------------- # FilterReusedOffloadingManager — reuse-frequency gating for CPU offload stores @@ -300,3 +303,6 @@ def complete_store(self, keys: Iterable[OffloadKey], success: bool = True) -> No def take_events(self) -> Iterable[OffloadingEvent]: return self._backing.take_events() + + def request_finished(self, req_id: str) -> bool: + return self._backing.request_finished(req_id) From f6213ea9e54fcdde0dff5391fe0efbb2da05652e Mon Sep 17 00:00:00 2001 From: Martin Hickey Date: Thu, 30 Apr 2026 12:47:10 +0100 Subject: [PATCH 4/4] Refactor after changes to scheduling in main Signed-off-by: Martin Hickey --- tests/v1/kv_offload/cpu/test_manager.py | 6 +- .../kv_connector/v1/offloading/policy.py | 188 ++++++++++------ .../kv_connector/v1/offloading/scheduler.py | 209 +++++------------- vllm/v1/kv_offload/base.py | 12 - vllm/v1/kv_offload/cpu/manager.py | 5 +- vllm/v1/kv_offload/cpu/spec.py | 3 - 6 files changed, 182 insertions(+), 241 deletions(-) diff --git a/tests/v1/kv_offload/cpu/test_manager.py b/tests/v1/kv_offload/cpu/test_manager.py index ef5d61e7b3d2..bc9aed999c4b 100644 --- a/tests/v1/kv_offload/cpu/test_manager.py +++ b/tests/v1/kv_offload/cpu/test_manager.py @@ -15,9 +15,11 @@ 
make_offload_key, ) from vllm.v1.kv_offload.cpu.common import CPULoadStoreSpec -from vllm.v1.kv_offload.cpu.manager import CPUOffloadingManager +from vllm.v1.kv_offload.cpu.manager import ( + CPUOffloadingManager, + FilterReusedOffloadingManager, +) from vllm.v1.kv_offload.cpu.policies.arc import ARCCachePolicy -from vllm.v1.kv_offload.reuse_manager import FilterReusedOffloadingManager def make_req_context(kv_transfer_params: dict | None = None) -> ReqContext: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/policy.py index 7ed3926650e9..84e3e636bdcf 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/policy.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/policy.py @@ -3,13 +3,12 @@ from __future__ import annotations from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import TYPE_CHECKING -from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import ReqId from vllm.logger import init_logger from vllm.v1.kv_offload.base import GPULoadStoreSpec, OffloadingManager, OffloadKey -from vllm.v1.kv_offload.worker.worker import TransferSpec if TYPE_CHECKING: from vllm.distributed.kv_transfer.kv_connector.v1.offloading.scheduler import ( @@ -17,10 +16,21 @@ SchedulerOffloadConfig, ) from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.kv_offload.base import LoadStoreSpec logger = init_logger(__name__) +@dataclass +class StorePlanEntry: + """Store decision for one request returned by OffloadPolicy.""" + + src_spec: GPULoadStoreSpec + dst_spec: LoadStoreSpec + keys: set[OffloadKey] + gpu_block_ids: list[int] + + class OffloadPolicy(ABC): """ Decides which KV cache blocks to offload each scheduler step. 
@@ -36,21 +46,25 @@ def get_blocks_to_store( scheduler_output: SchedulerOutput, config: SchedulerOffloadConfig, manager: OffloadingManager, - reqs_being_stored: dict[ReqId, set[OffloadKey]], - ) -> dict[ReqId, TransferSpec]: + ) -> dict[ReqId, StorePlanEntry]: """ Decide which blocks to store this scheduler step. + Called after the scheduler has applied block-ID updates and fence + checks for the current step. Implementations read the already-updated + ``req_kv_states`` and ``scheduler_output.num_scheduled_tokens`` to + determine which blocks are newly computable and eligible for transfer. + Args: - req_kv_states: per-request KV tracking state. + req_kv_states: per-request KV tracking state (block IDs already + updated by the caller for this step). scheduler_output: the current scheduler output. config: offloading configuration. manager: the offloading manager to call prepare_store on. - reqs_being_stored: scheduler-owned dict of in-flight store keys, - updated in-place for each request that gets a store queued. Returns: - A dict mapping request ID to the TransferSpec to submit. + A dict mapping request ID to a StorePlanEntry describing the + transfer to submit. 
""" pass @@ -94,90 +108,130 @@ def get_blocks_to_store( scheduler_output: SchedulerOutput, config: SchedulerOffloadConfig, manager: OffloadingManager, - reqs_being_stored: dict[ReqId, set[OffloadKey]], - ) -> dict[ReqId, TransferSpec]: - # Below assertion will be removed once this function supports HMA - assert len(config.kv_group_configs) == 1 - group_config = config.kv_group_configs[0] - - reqs_to_store: dict[ReqId, TransferSpec] = {} - for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output): - req_kv_state = req_kv_states[req_id] - req_kv_state.update_offload_keys() - - if preempted: - for group_state in req_kv_state.group_states: - group_state.block_ids.clear() - - if new_block_id_groups: - req_kv_state.update_block_id_groups(new_block_id_groups) - - # Below assertion will be removed once this function supports HMA - assert len(req_kv_state.group_states) == 1 - group_state = req_kv_state.group_states[0] - - block_ids = group_state.block_ids + ) -> dict[ReqId, StorePlanEntry]: + block_size_factor = config.block_size_factor + reqs_to_store: dict[ReqId, StorePlanEntry] = {} + for req_id in scheduler_output.num_scheduled_tokens: + req_kv_state = req_kv_states.get(req_id) + if req_kv_state is None: + continue + req_kv_state.update_offload_keys() req = req_kv_state.req - new_tokens = scheduler_output.num_scheduled_tokens[req_id] - expected_tokens = req.num_computed_tokens + new_tokens - total_tokens = min(expected_tokens, req.num_tokens) - num_blocks = total_tokens // group_config.offloaded_block_size - if req_id not in self._next_stored_block_idx: - self._next_stored_block_idx[req_id] = [0] * len( - req_kv_state.group_states - ) - start_block_idx = self._next_stored_block_idx[req_id][0] - num_new_blocks = num_blocks - start_block_idx + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_tokens_after_batch = req.num_computed_tokens + num_scheduled_tokens + num_offloadable_tokens = min(num_tokens_after_batch, 
req.num_tokens) - if num_new_blocks <= 0: + if req_id not in self._next_stored_block_idx: + self._next_stored_block_idx[req_id] = [ + 0 for _ in config.kv_group_configs + ] + watermark = self._next_stored_block_idx[req_id] + + # Collect eligible offload keys across all groups, filtering out + # blocks skipped due to sliding window attention or SSM. + new_offload_keys: list[OffloadKey] = [] + for group_idx, (group_config, group_state) in enumerate( + zip(config.kv_group_configs, req_kv_state.group_states) + ): + num_blocks = num_offloadable_tokens // group_config.offloaded_block_size + start_block_idx = watermark[group_idx] + if num_blocks <= start_block_idx: + continue + offload_keys = group_state.offload_keys[start_block_idx:num_blocks] + # Take the last GPU block of each offloaded block to determine + # whether the block was skipped (block_id == 0). + offload_block_ids = group_state.block_ids[ + start_block_idx * block_size_factor + + block_size_factor + - 1 : num_blocks * block_size_factor : block_size_factor + ] + assert len(offload_keys) == len(offload_block_ids) + for offload_key, block_id in zip(offload_keys, offload_block_ids): + if block_id != 0: + new_offload_keys.append(offload_key) + + if not new_offload_keys: + # No new blocks to store; advance the watermark. 
+ for group_idx, group_config in enumerate(config.kv_group_configs): + num_blocks = ( + num_offloadable_tokens // group_config.offloaded_block_size + ) + watermark[group_idx] = max(watermark[group_idx], num_blocks) continue - num_gpu_blocks = num_blocks * config.block_size_factor - assert len(req.block_hashes) >= num_gpu_blocks - - new_offload_keys = group_state.offload_keys[start_block_idx:num_blocks] store_output = manager.prepare_store( new_offload_keys, req_kv_state.req_context ) if store_output is None: - logger.warning( - "Request %s: cannot store %s blocks", req_id, num_new_blocks - ) + logger.warning("Request %s: cannot store blocks", req_id) continue - self._next_stored_block_idx[req_id][0] = num_blocks - if not store_output.keys_to_store: + # Manager declined; advance the watermark. + for group_idx, group_config in enumerate(config.kv_group_configs): + num_blocks = ( + num_offloadable_tokens // group_config.offloaded_block_size + ) + watermark[group_idx] = max(watermark[group_idx], num_blocks) continue - keys_to_store = set(store_output.keys_to_store) - manager.touch(group_state.offload_keys[:num_blocks]) + for group_state in req_kv_state.group_states: + manager.touch(group_state.offload_keys) - dst_spec = store_output.store_spec + keys_to_store = set(store_output.keys_to_store) + + group_sizes: list[int] = [] + block_indices: list[int] = [] src_block_ids: list[int] = [] - for idx, key in enumerate(new_offload_keys): - if key not in keys_to_store: - continue - offloaded_block_idx = start_block_idx + idx - gpu_block_idx = offloaded_block_idx * config.block_size_factor - for i in range(config.block_size_factor): - src_block_ids.append(block_ids[gpu_block_idx + i]) + for group_idx, (group_config, group_state) in enumerate( + zip(config.kv_group_configs, req_kv_state.group_states) + ): + num_blocks = num_offloadable_tokens // group_config.offloaded_block_size + start_block_idx = watermark[group_idx] + block_ids = group_state.block_ids + num_group_blocks = 0 + 
start_gpu_block_idx: int | None = None + for idx, offload_key in enumerate( + group_state.offload_keys[start_block_idx:num_blocks] + ): + if offload_key not in keys_to_store: + continue + offloaded_block_idx = start_block_idx + idx + gpu_block_idx = offloaded_block_idx * block_size_factor + num_group_blocks += block_size_factor + for i in range(block_size_factor): + block_id = block_ids[gpu_block_idx + i] + if block_id == 0: + # Skipped blocks cannot appear after non-skipped blocks. + assert start_gpu_block_idx is None + continue + elif start_gpu_block_idx is None: + start_gpu_block_idx = gpu_block_idx + i + src_block_ids.append(block_id) + group_sizes.append(num_group_blocks) + block_indices.append(start_gpu_block_idx or 0) + watermark[group_idx] = num_blocks + src_spec = GPULoadStoreSpec( src_block_ids, - group_sizes=(len(src_block_ids),), - block_indices=(0,), + group_sizes=tuple(group_sizes), + block_indices=tuple(block_indices), ) + dst_spec = store_output.store_spec - reqs_to_store[req_id] = (src_spec, dst_spec) - reqs_being_stored[req_id] |= keys_to_store + reqs_to_store[req_id] = StorePlanEntry( + src_spec=src_spec, + dst_spec=dst_spec, + keys=keys_to_store, + gpu_block_ids=src_block_ids, + ) logger.debug( - "Request %s offloading %s blocks starting from block #%d", + "Request %s: queuing store for %s blocks", req_id, len(keys_to_store), - start_block_idx, ) return reqs_to_store diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py index ec9995055173..b45ae5171aba 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py @@ -6,6 +6,7 @@ from typing import Any, NamedTuple from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent +from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data from 
vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import ( OffloadingConnectorMetadata, @@ -16,6 +17,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.offloading.policy import ( OffloadPolicy, StoreOnComputePolicy, + StorePlanEntry, ) from vllm.logger import init_logger from vllm.utils.math_utils import cdiv @@ -133,13 +135,6 @@ def update_block_id_groups( for group_state, new_blocks in zip(self.group_states, new_block_id_groups): group_state.block_ids.extend(new_blocks) - def advance_stored_idx(self, num_offloadable_tokens: int) -> None: - for group_config, group_state in zip( - self.config.kv_group_configs, self.group_states - ): - num_blocks = num_offloadable_tokens // group_config.offloaded_block_size - group_state.next_stored_block_idx = num_blocks - class OffloadingConnectorScheduler: """Implementation of Scheduler side methods""" @@ -341,7 +336,7 @@ def update_state_after_alloc( return req_status = self._req_kv_states[request.request_id] - + num_locally_computed_tokens = req_status.num_locally_computed_tokens num_cached_tokens = num_locally_computed_tokens + num_external_tokens @@ -353,6 +348,7 @@ def update_state_after_alloc( # per group group_sizes: list[int] = [] block_indices: list[int] = [] + next_block_idx_per_group: list[int] = [] for group_config, group_state, group_blocks in zip( self.config.kv_group_configs, req_status.group_states, @@ -393,12 +389,15 @@ def update_state_after_alloc( ) group_sizes.append(num_pending_gpu_blocks) block_indices.append(num_locally_computed_gpu_blocks) - - if not do_remote_decode: - # For P/D prefill requests (do_remote_decode=True), we do - # NOT skip saving the hit prefix, as we need to stream the - # entire KV cache so a remote decode node can consume it. 
- group_state.next_stored_block_idx = num_blocks + next_block_idx_per_group.append(num_blocks) + + if not do_remote_decode: + # For P/D prefill requests (do_remote_decode=True), we do NOT skip + # saving the hit prefix — the entire KV cache must be streamed to + # the remote decode node. + self._policy.notify_load_scheduled( + request.request_id, next_block_idx_per_group + ) # Fence dst blocks against finished-request pending stores. if ( @@ -430,8 +429,6 @@ def update_state_after_alloc( keys=set(keys_to_load), is_store=False, ) - # MH TODO: what happen here now? - # self._policy.notify_load_scheduled(request.request_id, [num_blocks]) if self._blocks_being_loaded is not None: self._blocks_being_loaded.update(keys_to_load) @@ -440,25 +437,20 @@ def _build_store_jobs( self, scheduler_output: SchedulerOutput, ) -> dict[int, TransferJob]: - block_size_factor = self.config.block_size_factor - store_jobs: dict[int, TransferJob] = {} - # iterate over both new and cached requests + # Pre-pass: apply block-ID updates from the scheduler output and check + # fences before delegating the store decision to the policy. for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output): - req_status = self._req_status[req_id] - req_status.update_offload_keys() - req = req_status.req - + req_kv_state = self._req_kv_states.get(req_id) + if req_kv_state is None: + continue if preempted: - for group_state in req_status.group_states: + for group_state in req_kv_state.group_states: group_state.block_ids.clear() - if new_block_id_groups: - req_status.update_block_id_groups(new_block_id_groups) - # Fence new blocks against in-flight stores. + req_kv_state.update_block_id_groups(new_block_id_groups) + # Fence new blocks against in-flight stores from finished requests. 
if self._block_id_to_pending_jobs: - new_blocks_flat = [ - bid for new_blocks in new_block_id_groups for bid in new_blocks - ] + new_blocks_flat = [bid for gs in new_block_id_groups for bid in gs] if not self._block_id_to_pending_jobs.keys().isdisjoint( new_blocks_flat ): @@ -468,141 +460,46 @@ def _build_store_jobs( for jid in self._block_id_to_pending_jobs.get(bid, ()) ) - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_tokens_after_batch = req.num_computed_tokens + num_scheduled_tokens - # with async scheduling, some tokens may be missing - num_offloadable_tokens = min(num_tokens_after_batch, req.num_tokens) - - # Filter out blocks skipped due to sliding window attention / SSM - new_offload_keys: list[OffloadKey] = [] - for group_config, group_state in zip( - self.config.kv_group_configs, req_status.group_states - ): - num_blocks = num_offloadable_tokens // group_config.offloaded_block_size - start_block_idx = group_state.next_stored_block_idx - if num_blocks <= start_block_idx: - continue - offload_keys = group_state.offload_keys[start_block_idx:num_blocks] - # For each block to offload, take the last corresponding GPU block. - # e.g. if block size factor is 3 and GPU block IDs are - # 1 5 6 7 2 4 9 3 8 then we'll take blocks 6 4 8. - # We will use these GPU blocks to determine if the block needs - # offloading, or (if the GPU block ID is 0) this block should - # be skipped due to sliding window attention / SSM. - # We know that if a block is skipped, then all the previous blocks - # are skipped as well. This is why we take the last of each block. 
- offload_block_ids = group_state.block_ids[ - start_block_idx * block_size_factor - + block_size_factor - - 1 : num_blocks * block_size_factor : block_size_factor - ] - assert len(offload_keys) == len(offload_block_ids) - - for offload_key, block_id in zip(offload_keys, offload_block_ids): - if block_id != 0: - new_offload_keys.append(offload_key) - - if not new_offload_keys: - req_status.advance_stored_idx(num_offloadable_tokens) - continue - - store_output = self.manager.prepare_store( - new_offload_keys, req_status.req_context - ) - if store_output is None: - logger.warning("Request %s: cannot store blocks", req_id) - continue - - if not store_output.keys_to_store: - req_status.advance_stored_idx(num_offloadable_tokens) - continue - - for group_state in req_status.group_states: - self.manager.touch(group_state.offload_keys) - - keys_to_store = set(store_output.keys_to_store) - - group_sizes: list[int] = [] - block_indices: list[int] = [] - src_block_ids: list[int] = [] - for group_config, group_state in zip( - self.config.kv_group_configs, req_status.group_states - ): - num_blocks = num_offloadable_tokens // group_config.offloaded_block_size - start_block_idx = group_state.next_stored_block_idx - block_ids = group_state.block_ids - num_group_blocks = 0 - start_gpu_block_idx: int | None = None - for idx, offload_key in enumerate( - group_state.offload_keys[start_block_idx:num_blocks] - ): - if offload_key not in keys_to_store: - continue - - offloaded_block_idx = start_block_idx + idx - gpu_block_idx = offloaded_block_idx * block_size_factor - num_group_blocks += block_size_factor - for i in range(block_size_factor): - block_id = block_ids[gpu_block_idx + i] - if block_id == 0: - # skipped blocks cannot appear after non-skipped blocks - assert start_gpu_block_idx is None - continue - elif start_gpu_block_idx is None: - start_gpu_block_idx = gpu_block_idx + i - src_block_ids.append(block_id) - group_sizes.append(num_group_blocks) - 
block_indices.append(start_gpu_block_idx or 0) - group_state.next_stored_block_idx = num_blocks - - src_spec = GPULoadStoreSpec( - src_block_ids, group_sizes=group_sizes, block_indices=block_indices - ) - dst_spec = store_output.store_spec + # Policy decides which blocks to store this step. + store_plans: dict[str, StorePlanEntry] = self._policy.get_blocks_to_store( + self._req_kv_states, scheduler_output, self.config, self.manager + ) + # Wrap each plan in a TransferJob and register it. + store_jobs: dict[int, TransferJob] = {} + for req_id, plan in store_plans.items(): + req_kv_state = self._req_kv_states[req_id] job_id = self._generate_job_id() - # a store can only be issued when no load is pending. - if req_status.transfer_jobs: - any_jid = next(iter(req_status.transfer_jobs)) + # A store can only be issued when no load is pending. + if req_kv_state.transfer_jobs: + any_jid = next(iter(req_kv_state.transfer_jobs)) assert self._jobs[any_jid].is_store - req_status.transfer_jobs.add(job_id) + req_kv_state.transfer_jobs.add(job_id) self._jobs[job_id] = TransferJobStatus( req_id=req_id, pending_count=self.config.num_workers, - keys=set(keys_to_store), + keys=plan.keys, is_store=True, - gpu_block_ids=src_block_ids, + gpu_block_ids=plan.gpu_block_ids, ) - store_jobs[job_id] = TransferJob( - req_id=req_id, transfer_spec=(src_spec, dst_spec) + req_id=req_id, + transfer_spec=(plan.src_spec, plan.dst_spec), ) - logger.debug( - "Request %s offloading %s blocks upto %d tokens (job %d)", + "Request %s offloading %s blocks (job %d)", req_id, - len(keys_to_store), - num_offloadable_tokens, + len(plan.keys), job_id, ) return store_jobs - - # MH TODO: What to do with this now as _build_store_jobs() replaced this func - """def _get_reqs_to_store(self, scheduler_output: SchedulerOutput): - return self._policy.get_blocks_to_store( - self._req_kv_states, - scheduler_output, - self.config, - self.manager, - self._reqs_being_stored, - )""" def build_connector_meta( self, 
scheduler_output: SchedulerOutput ) -> KVConnectorMetadata: for req_id in scheduler_output.preempted_req_ids or (): - req_status = self._req_status.get(req_id) + req_status = self._req_kv_states.get(req_id) if req_status is None or not req_status.transfer_jobs: continue any_jid = next(iter(req_status.transfer_jobs)) @@ -645,7 +542,7 @@ def update_connector_output(self, connector_output: KVConnectorOutput): if self._blocks_being_loaded: self._blocks_being_loaded.difference_update(job_status.keys) - req_status = self._req_status[job_status.req_id] + req_status = self._req_kv_states[job_status.req_id] if self._block_id_to_pending_jobs and req_status.req.is_finished(): for bid in job_status.gpu_block_ids or (): pending = self._block_id_to_pending_jobs[bid] @@ -656,7 +553,7 @@ def update_connector_output(self, connector_output: KVConnectorOutput): del self._jobs[job_id] req_status.transfer_jobs.remove(job_id) if not req_status.transfer_jobs and req_status.req.is_finished(): - del self._req_status[job_status.req_id] + del self._req_kv_states[job_status.req_id] def request_finished( self, @@ -675,25 +572,27 @@ def request_finished( """ # TODO(orozery): possibly kickoff offload for last block # which may have been deferred due to async scheduling - req_status = self._req_kv_states.get(request.request_id) + req_id = request.request_id + req_status = self._req_kv_states.get(req_id) if req_status is None: return False, None + + self._policy.request_finished(req_id) + if not req_status.transfer_jobs: - del self._req_kv_states[request.request_id] - return False, None + del self._req_kv_states[req_id] + return self.manager.request_finished(req_id), None + # Pending stores will outlive the request's block ownership. # Register them so future block reuse triggers a flush. 
for job_id in req_status.transfer_jobs: job_status = self._jobs[job_id] for bid in job_status.gpu_block_ids or (): self._block_id_to_pending_jobs.setdefault(bid, set()).add(job_id) - return False, None - - # MH TODO: How to merge this or how does it work now? - """self._req_kv_states.pop(req_id, None) - self._policy.request_finished(req_id) - return self.manager.request_finished(req_id), None""" + # _req_kv_states[req_id] will be cleaned up in update_connector_output + # once all in-flight transfer jobs complete. + return self.manager.request_finished(req_id), None def take_events(self) -> Iterable[KVCacheEvent]: """Take the KV cache events from the connector. diff --git a/vllm/v1/kv_offload/base.py b/vllm/v1/kv_offload/base.py index c7609a2992dc..357d78bd49da 100644 --- a/vllm/v1/kv_offload/base.py +++ b/vllm/v1/kv_offload/base.py @@ -128,11 +128,7 @@ def lookup(self, key: OffloadKey, req_context: ReqContext) -> bool | None: @abstractmethod def prepare_load( self, -<<<<<<< HEAD keys: Sequence[OffloadKey], -======= - keys: Iterable[OffloadKey], ->>>>>>> 83731c08b (Refactor KV Offloading) req_context: ReqContext, ) -> LoadStoreSpec: """ @@ -151,11 +147,7 @@ def prepare_load( """ pass -<<<<<<< HEAD def touch(self, keys: Sequence[OffloadKey]): -======= - def touch(self, keys: Iterable[OffloadKey]): ->>>>>>> 83731c08b (Refactor KV Offloading) """ Mark the given blocks as recently used. This could in practice mean moving them to the end of an LRU list. 
@@ -177,11 +169,7 @@ def complete_load(self, keys: Iterable[OffloadKey]): @abstractmethod def prepare_store( self, -<<<<<<< HEAD keys: Sequence[OffloadKey], -======= - keys: Iterable[OffloadKey], ->>>>>>> 83731c08b (Refactor KV Offloading) req_context: ReqContext, ) -> PrepareStoreOutput | None: """ diff --git a/vllm/v1/kv_offload/cpu/manager.py b/vllm/v1/kv_offload/cpu/manager.py index 770981e6dbec..47e62694b58f 100644 --- a/vllm/v1/kv_offload/cpu/manager.py +++ b/vllm/v1/kv_offload/cpu/manager.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import OrderedDict from collections.abc import Iterable, Sequence from typing import Literal @@ -288,11 +289,11 @@ def prepare_store( # ------------------------------------------------------------------ def prepare_load( - self, keys: Iterable[OffloadKey], req_context: ReqContext + self, keys: Sequence[OffloadKey], req_context: ReqContext ) -> LoadStoreSpec: return self._backing.prepare_load(keys, req_context) - def touch(self, keys: Iterable[OffloadKey]) -> None: + def touch(self, keys: Sequence[OffloadKey]) -> None: return self._backing.touch(keys) def complete_load(self, keys: Iterable[OffloadKey]) -> None: diff --git a/vllm/v1/kv_offload/cpu/spec.py b/vllm/v1/kv_offload/cpu/spec.py index 17b77da8d502..fdd8f2ddfabc 100644 --- a/vllm/v1/kv_offload/cpu/spec.py +++ b/vllm/v1/kv_offload/cpu/spec.py @@ -6,7 +6,6 @@ from vllm.platforms import current_platform from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_offload.base import ( - BlockIDsLoadStoreSpec, CanonicalKVCaches, GPULoadStoreSpec, LoadStoreSpec, @@ -15,8 +14,6 @@ ) from vllm.v1.kv_offload.cpu.common import CPULoadStoreSpec from vllm.v1.kv_offload.cpu.gpu_worker import CpuGpuOffloadingHandlers -from vllm.v1.kv_offload.cpu.manager import CPUOffloadingManager -from vllm.v1.kv_offload.reuse_manager import FilterReusedOffloadingManager from 
vllm.v1.kv_offload.worker.worker import OffloadingHandler