diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/policy.py
new file mode 100644
index 000000000000..13552950aa2e
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/policy.py
@@ -0,0 +1,128 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+
+from vllm.distributed.kv_transfer.kv_connector.v1.offloading.state import (
+    RequestKVState,
+    SchedulerOffloadConfig,
+)
+from vllm.v1.kv_offload.base import OffloadKey
+
+
+class OffloadPolicy(ABC):
+    """Abstraction for deciding which KV blocks to store each scheduler step."""
+
+    @abstractmethod
+    def get_blocks_to_store(
+        self,
+        req_kv_state: RequestKVState,
+        num_offloadable_tokens: int,
+    ) -> tuple[list[OffloadKey], list[int]]:
+        """Return (keys_to_store, per_group_start_idx) for this scheduler step.
+
+        The implementation is responsible for tracking per-request progress
+        and advancing it on every call so the same blocks are not returned
+        twice.
+
+        Args:
+            req_kv_state: current KV state of the request. Read-only from
+                the policy's perspective.
+            num_offloadable_tokens: token count available for offloading
+                after this scheduler step.
+
+        Returns:
+            A 2-tuple of:
+            - Possibly-empty list of OffloadKey values to store.
+            - Per-group starting block index (one entry per KV cache group),
+              indicating where the new keys begin in each group's offload_keys
+              list. Callers may use this to skip already-processed blocks.
+        """
+        ...
+
+    def on_blocks_loaded(
+        self,
+        req_id: str,
+        num_offloadable_tokens: int,
+    ) -> None:
+        """Called when blocks are being loaded so the policy can advance past them.
+
+        Args:
+            req_id: the request being loaded.
+            num_offloadable_tokens: token count up to which blocks are loaded.
+        """
+        return
+
+    def request_finished(self, req_id: str) -> None:
+        """Release any per-request state held by the policy."""
+        return
+
+
+class StoreOnComputePolicy(OffloadPolicy):
+    """Store blocks as soon as they are computed (the default policy).
+
+    Tracks per-request, per-group progress so that each block is submitted
+    for offloading exactly once, in order.
+    """
+
+    def __init__(self, config: SchedulerOffloadConfig) -> None:
+        self._config = config
+        self._block_size_factor: int = config.block_size_factor
+        # req_id -> per-group next stored block index
+        self._stored_idx: dict[str, list[int]] = {}
+
+    def get_blocks_to_store(
+        self,
+        req_kv_state: RequestKVState,
+        num_offloadable_tokens: int,
+    ) -> tuple[list[OffloadKey], list[int]]:
+        req_id = req_kv_state.req.request_id
+        stored = self._stored_idx.get(req_id)
+        if stored is None:
+            stored = [0] * len(self._config.kv_group_configs)
+            self._stored_idx[req_id] = stored
+        new_offload_keys: list[OffloadKey] = []
+        per_group_start: list[int] = []
+        for group_idx, group_config in enumerate(self._config.kv_group_configs):
+            group_state = req_kv_state.group_states[group_idx]
+            num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
+            start_block_idx = stored[group_idx]
+            per_group_start.append(start_block_idx)
+            if num_blocks <= start_block_idx:
+                continue
+            offload_keys = group_state.offload_keys[start_block_idx:num_blocks]
+            # For each offloaded block, inspect the last corresponding GPU block.
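+            # e.g. if the block size factor is 3 and the GPU block IDs are
+            # 1 5 6 7 2 4 9 3 8, the strided slice below takes blocks 6 4 8.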
+            # A block_id of 0 indicates a sliding-window / SSM padding slot that
+            # should be skipped; we know all earlier blocks are skipped too.
+            offload_block_ids = group_state.block_ids[
+                start_block_idx * self._block_size_factor
+                + self._block_size_factor
+                - 1 : num_blocks * self._block_size_factor : self._block_size_factor
+            ]
+            assert len(offload_keys) == len(offload_block_ids)
+            for offload_key, block_id in zip(offload_keys, offload_block_ids):
+                if block_id != 0:
+                    new_offload_keys.append(offload_key)
+            # Always advance regardless of prepare_store filtering later.
+            stored[group_idx] = num_blocks
+        return new_offload_keys, per_group_start
+
+    def on_blocks_loaded(
+        self,
+        req_id: str,
+        num_offloadable_tokens: int,
+    ) -> None:
+        # Use setdefault so that a load preceding the first store call still
+        # advances the index, preventing already-loaded blocks from being
+        # returned by a subsequent get_blocks_to_store call.
+        stored = self._stored_idx.setdefault(
+            req_id, [0] * len(self._config.kv_group_configs)
+        )
+        for group_idx, group_config in enumerate(self._config.kv_group_configs):
+            num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
+            stored[group_idx] = num_blocks
+
+    def request_finished(self, req_id: str) -> None:
+        self._stored_idx.pop(req_id, None)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
index 773fe8f056ac..ec2d377224c6 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
@@ -1,9 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Sequence
-from dataclasses import dataclass, field
-from itertools import islice
-from typing import Any, NamedTuple
+from dataclasses import dataclass
+from typing import Any
 
 from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
 from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
@@ -14,16 +13,20 @@
     ReqId,
     TransferJob,
 )
+from vllm.distributed.kv_transfer.kv_connector.v1.offloading.policy import (
+    OffloadPolicy,
+    StoreOnComputePolicy,
+)
+from vllm.distributed.kv_transfer.kv_connector.v1.offloading.state import (
+    GroupOffloadConfig,
+    RequestGroupState,
+    RequestKVState,
+    SchedulerOffloadConfig,
+)
 from vllm.logger import init_logger
 from vllm.utils.math_utils import cdiv
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.kv_cache_interface import (
-    FullAttentionSpec,
-    KVCacheSpec,
-    MambaSpec,
-    SlidingWindowSpec,
-)
 from vllm.v1.kv_offload.base import (
     GPULoadStoreSpec,
     OffloadingManager,
@@ -31,7 +34,6 @@
     OffloadKey,
     ReqContext,
     get_offload_block_hash,
-    make_offload_key,
 )
 from vllm.v1.outputs import KVConnectorOutput
 from vllm.v1.request import Request
@@ -58,136 +60,13 @@ class TransferJobStatus:
     sliding_window_block_ids: list[int] | None = None
 
 
-class GroupOffloadConfig(NamedTuple):
-    group_idx: int
-    gpu_block_size: int
-    offloaded_block_size: int
-    hash_block_size_factor: int
-    # None below means full attention
-    sliding_window_size_in_blocks: int | None
-
-
-def get_sliding_window_size_in_blocks(
-    kv_cache_spec: KVCacheSpec, offloaded_block_size: int
-) -> int | None:
-    if isinstance(kv_cache_spec, SlidingWindowSpec):
-        assert kv_cache_spec.sliding_window > 0
-        return cdiv(kv_cache_spec.sliding_window, offloaded_block_size)
-
-    if isinstance(kv_cache_spec, MambaSpec):
-        # Mamba depends on a single state
-        return 1
-
-    assert isinstance(kv_cache_spec, FullAttentionSpec)
-    return None
-
-
-class SchedulerOffloadConfig(NamedTuple):
-    kv_group_configs: tuple[GroupOffloadConfig, ...]
-    block_size_factor: int
-    num_workers: int
-
-    @classmethod
-    def from_spec(cls, spec: OffloadingSpec) -> "SchedulerOffloadConfig":
-        return cls(
-            num_workers=spec.vllm_config.parallel_config.world_size,
-            kv_group_configs=tuple(
-                GroupOffloadConfig(
-                    group_idx=idx,
-                    gpu_block_size=gpu_block_size,
-                    offloaded_block_size=gpu_block_size * spec.block_size_factor,
-                    hash_block_size_factor=(
-                        (gpu_block_size * spec.block_size_factor)
-                        // spec.hash_block_size
-                    ),
-                    sliding_window_size_in_blocks=get_sliding_window_size_in_blocks(
-                        spec.kv_cache_config.kv_cache_groups[idx].kv_cache_spec,
-                        gpu_block_size * spec.block_size_factor,
-                    ),
-                )
-                for idx, gpu_block_size in enumerate(spec.gpu_block_size)
-            ),
-            block_size_factor=spec.block_size_factor,
-        )
-
-
-@dataclass
-class RequestGroupState:
-    offload_keys: list[OffloadKey] = field(default_factory=list)
-    block_ids: list[int] = field(default_factory=list)
-    # index of next block (of size offloaded_block_size) to offload
-    next_stored_block_idx: int = 0
-    # number of offloaded blocks hit (including GPU prefix cache)
-    # when the request first started
-    num_hit_blocks: int = 0
-
-
-@dataclass(slots=True)
-class RequestOffloadState:
-    config: SchedulerOffloadConfig
-    req: Request
-    group_states: tuple[RequestGroupState, ...] = field(init=False)
-    req_context: ReqContext = field(init=False)
-    # number of hits in the GPU cache
-    num_locally_computed_tokens: int = 0
-    # In-flight job IDs. Per the connector's invariant, at any given time
-    # this contains either a single load job, or one or more store jobs.
-    transfer_jobs: set[int] = field(default_factory=set)
-
-    def __post_init__(self) -> None:
-        self.group_states = tuple(
-            RequestGroupState() for _ in self.config.kv_group_configs
-        )
-        self.req_context = ReqContext(kv_transfer_params=self.req.kv_transfer_params)
-
-    def update_offload_keys(self) -> None:
-        for group_config, group_state in zip(
-            self.config.kv_group_configs, self.group_states
-        ):
-            for req_block_hash in islice(
-                self.req.block_hashes,
-                group_config.hash_block_size_factor * len(group_state.offload_keys)
-                + group_config.hash_block_size_factor
-                - 1,
-                None,
-                group_config.hash_block_size_factor,
-            ):
-                group_state.offload_keys.append(
-                    make_offload_key(req_block_hash, group_config.group_idx)
-                )
-
-    def update_block_id_groups(
-        self, new_block_id_groups: tuple[list[int], ...] | None
-    ) -> None:
-        if new_block_id_groups is None:
-            return
-
-        assert len(new_block_id_groups) == len(self.group_states)
-        for group_state, new_blocks in zip(self.group_states, new_block_id_groups):
-            group_state.block_ids.extend(new_blocks)
-
-    def advance_stored_idx(self, num_offloadable_tokens: int) -> None:
-        for group_config, group_state in zip(
-            self.config.kv_group_configs, self.group_states
-        ):
-            num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
-            group_state.next_stored_block_idx = num_blocks
-
-    def update_num_hit_blocks(self, num_cached_tokens: int) -> None:
-        for group_config, group_state in zip(
-            self.config.kv_group_configs, self.group_states
-        ):
-            group_state.num_hit_blocks = (
-                num_cached_tokens // group_config.offloaded_block_size
-            )
-
-
 class OffloadingConnectorScheduler:
     """Implementation of Scheduler side methods"""
 
     def __init__(self, spec: OffloadingSpec):
         self.config = SchedulerOffloadConfig.from_spec(spec)
         self.manager: OffloadingManager = spec.get_manager()
+        self.policy: OffloadPolicy = StoreOnComputePolicy(self.config)
 
         full_attention_groups: list[int] = []
         sliding_window_groups: list[int] = []
@@ -209,7 +88,7 @@ def _sliding_window_sort_key(i: int) -> int:
         self._sliding_window_groups: tuple[int, ...] = tuple(sliding_window_groups)
         self._lookup_groups = tuple(full_attention_groups) + self._sliding_window_groups
 
-        self._req_status: dict[ReqId, RequestOffloadState] = {}
+        self._req_status: dict[ReqId, RequestKVState] = {}
         self._current_batch_load_jobs: dict[int, TransferJob] = {}
         self._current_batch_jobs_to_flush: set[int] = set()
         # if GPU prefix caching is enabled,
@@ -286,7 +165,7 @@ def _sliding_window_lookup(
             return idx + sliding_window_size if not defer_lookup else None
         return consecutive_hits if not defer_lookup else None
 
-    def _touch(self, req_status: RequestOffloadState):
+    def _touch(self, req_status: RequestKVState):
         for group_config, group_state in zip(
             self.config.kv_group_configs, req_status.group_states
         ):
@@ -302,7 +181,7 @@ def _touch(self, req_status: RequestOffloadState):
             )
             self.manager.touch(group_state.offload_keys[blocks_to_skip:])
 
-    def _lookup(self, req_status: RequestOffloadState) -> int | None:
+    def _lookup(self, req_status: RequestKVState) -> int | None:
         """
         Find how many tokens beyond num_locally_computed_tokens can be loaded.
 
@@ -469,7 +348,7 @@ def get_num_new_matched_tokens(
                 group_state.block_ids.clear()
         else:
             is_new_request = True
-            req_status = RequestOffloadState(config=self.config, req=request)
+            req_status = RequestKVState(config=self.config, req=request)
             self._req_status[request.request_id] = req_status
 
         req_status.update_offload_keys()
@@ -552,11 +431,11 @@ def update_state_after_alloc(
                 group_sizes.append(num_pending_gpu_blocks)
                 block_indices.append(num_locally_computed_gpu_blocks)
 
-            if not do_remote_decode:
-                # For P/D prefill requests (do_remote_decode=True), we do
-                # NOT skip saving the hit prefix, as we need to stream the
-                # entire KV cache so a remote decode node can consume it.
-                group_state.next_stored_block_idx = num_blocks
+            if not do_remote_decode:
+                # For P/D prefill requests (do_remote_decode=True), we do
+                # NOT skip saving the hit prefix, as we need to stream the
+                # entire KV cache so a remote decode node can consume it.
+                self.policy.on_blocks_loaded(request.request_id, num_cached_tokens)
 
         # Fence dst blocks against finished-request pending stores.
         if (
@@ -629,38 +508,11 @@ def _build_store_jobs(
             # with async scheduling, some tokens may be missing
            num_offloadable_tokens = min(num_tokens_after_batch, req.num_tokens)
 
-            # Filter out blocks skipped due to sliding window attention / SSM
-            new_offload_keys: list[OffloadKey] = []
-            for group_config, group_state in zip(
-                self.config.kv_group_configs, req_status.group_states
-            ):
-                num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
-                start_block_idx = group_state.next_stored_block_idx
-                if num_blocks <= start_block_idx:
-                    continue
-                offload_keys = group_state.offload_keys[start_block_idx:num_blocks]
-                # For each block to offload, take the last corresponding GPU block.
-                # e.g. if block size factor is 3 and GPU block IDs are
-                # 1 5 6 7 2 4 9 3 8 then we'll take blocks 6 4 8.
-                # We will use these GPU blocks to determine if the block needs
-                # offloading, or (if the GPU block ID is 0) this block should
-                # be skipped due to sliding window attention / SSM.
-                # We know that if a block is skipped, then all the previous blocks
-                # are skipped as well. This is why we take the last of each block.
-                offload_block_ids = group_state.block_ids[
-                    start_block_idx * block_size_factor
-                    + block_size_factor
-                    - 1 : num_blocks * block_size_factor : block_size_factor
-                ]
-                assert len(offload_keys) == len(offload_block_ids)
-
-                for offload_key, block_id in zip(offload_keys, offload_block_ids):
-                    if block_id != 0:
-                        new_offload_keys.append(offload_key)
-
+            new_offload_keys, per_group_start = self.policy.get_blocks_to_store(
+                req_status, num_offloadable_tokens
+            )
             if not new_offload_keys:
-                req_status.advance_stored_idx(num_offloadable_tokens)
-                continue
+                continue  # policy already advanced its index
 
             store_output = self.manager.prepare_store(
                 new_offload_keys, req_status.req_context
@@ -670,8 +522,7 @@
                 continue
 
             if not store_output.keys_to_store:
-                req_status.advance_stored_idx(num_offloadable_tokens)
-                continue
+                continue  # policy already advanced its index
 
             self._touch(req_status)
@@ -682,14 +533,14 @@
             src_block_ids: list[int] = []
             sliding_window_block_ids: list[int] = []
             non_sliding_window_block_ids: list[int] = []
-            for group_config, group_state in zip(
-                self.config.kv_group_configs, req_status.group_states
+            for group_idx, (group_config, group_state) in enumerate(
+                zip(self.config.kv_group_configs, req_status.group_states)
             ):
+                start_block_idx = per_group_start[group_idx]
                 is_sliding_window = (
                     group_config.sliding_window_size_in_blocks is not None
                 )
                 num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
-                start_block_idx = group_state.next_stored_block_idx
                 block_ids = group_state.block_ids
                 num_group_blocks = 0
                 start_gpu_block_idx: int | None = None
@@ -718,7 +569,6 @@
                 group_sizes.append(num_group_blocks)
                 block_indices.append(start_gpu_block_idx or 0)
-                group_state.next_stored_block_idx = num_blocks
 
             src_spec = GPULoadStoreSpec(
                 src_block_ids, group_sizes=group_sizes, block_indices=block_indices
@@ -840,11 +690,12 @@ def request_finished(
             Optional KVTransferParams to be included in the request outputs
            returned by the engine.
""" - # TODO(orozery): possibly kickoff offload for last block - # which may have been deferred due to async scheduling req_status = self._req_status.get(request.request_id) if req_status is None: return False, None + + self.manager.request_finished(req_status.req_context) + self.policy.request_finished(request.request_id) if not req_status.transfer_jobs: del self._req_status[request.request_id] return False, None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/state.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/state.py new file mode 100644 index 000000000000..2d86d8c739d0 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/state.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +from dataclasses import dataclass, field +from itertools import islice +from typing import NamedTuple + +from vllm.utils.math_utils import cdiv +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheSpec, + MambaSpec, + SlidingWindowSpec, +) +from vllm.v1.kv_offload.base import ( + OffloadingSpec, + OffloadKey, + ReqContext, + make_offload_key, +) +from vllm.v1.request import Request + + +class GroupOffloadConfig(NamedTuple): + group_idx: int + gpu_block_size: int + offloaded_block_size: int + hash_block_size_factor: int + # None below means full attention + sliding_window_size_in_blocks: int | None + + +def get_sliding_window_size_in_blocks( + kv_cache_spec: KVCacheSpec, offloaded_block_size: int +) -> int | None: + if isinstance(kv_cache_spec, SlidingWindowSpec): + assert kv_cache_spec.sliding_window > 0 + return cdiv(kv_cache_spec.sliding_window, offloaded_block_size) + + if isinstance(kv_cache_spec, MambaSpec): + # Mamba depends on a single state + return 1 + + assert isinstance(kv_cache_spec, FullAttentionSpec) + return None + + +class SchedulerOffloadConfig(NamedTuple): + kv_group_configs: tuple[GroupOffloadConfig, ...] + block_size_factor: int + num_workers: int + + @classmethod + def from_spec(cls, spec: OffloadingSpec) -> SchedulerOffloadConfig: + return cls( + num_workers=spec.vllm_config.parallel_config.world_size, + kv_group_configs=tuple( + GroupOffloadConfig( + group_idx=idx, + gpu_block_size=gpu_block_size, + offloaded_block_size=gpu_block_size * spec.block_size_factor, + hash_block_size_factor=( + (gpu_block_size * spec.block_size_factor) + // spec.hash_block_size + ), + sliding_window_size_in_blocks=get_sliding_window_size_in_blocks( + spec.kv_cache_config.kv_cache_groups[idx].kv_cache_spec, + gpu_block_size * spec.block_size_factor, + ), + ) + for idx, gpu_block_size in enumerate(spec.gpu_block_size) + ), + block_size_factor=spec.block_size_factor, + ) + + +@dataclass +class RequestGroupState: + offload_keys: list[OffloadKey] = field(default_factory=list) + block_ids: list[int] = field(default_factory=list) + # number of offloaded blocks hit (including GPU prefix cache) + # when the request first started + num_hit_blocks: int = 0 + + +@dataclass(slots=True) +class RequestKVState: + config: SchedulerOffloadConfig + req: Request + group_states: tuple[RequestGroupState, ...] = field(init=False) + req_context: ReqContext = field(init=False) + # number of hits in the GPU cache + num_locally_computed_tokens: int = 0 + # In-flight job IDs. Per the connector's invariant, at any given time + # this contains either a single load job, or one or more store jobs. 
+    transfer_jobs: set[int] = field(default_factory=set)
+
+    def __post_init__(self) -> None:
+        self.group_states = tuple(
+            RequestGroupState() for _ in self.config.kv_group_configs
+        )
+        self.req_context = ReqContext(kv_transfer_params=self.req.kv_transfer_params)
+
+    def update_offload_keys(self) -> None:
+        for group_config, group_state in zip(
+            self.config.kv_group_configs, self.group_states
+        ):
+            for req_block_hash in islice(
+                self.req.block_hashes,
+                group_config.hash_block_size_factor * len(group_state.offload_keys)
+                + group_config.hash_block_size_factor
+                - 1,
+                None,
+                group_config.hash_block_size_factor,
+            ):
+                group_state.offload_keys.append(
+                    make_offload_key(req_block_hash, group_config.group_idx)
+                )
+
+    def update_block_id_groups(
+        self, new_block_id_groups: tuple[list[int], ...] | None
+    ) -> None:
+        if new_block_id_groups is None:
+            return
+
+        assert len(new_block_id_groups) == len(self.group_states)
+        for group_state, new_blocks in zip(self.group_states, new_block_id_groups):
+            group_state.block_ids.extend(new_blocks)
+
+    def update_num_hit_blocks(self, num_cached_tokens: int) -> None:
+        for group_config, group_state in zip(
+            self.config.kv_group_configs, self.group_states
+        ):
+            group_state.num_hit_blocks = (
+                num_cached_tokens // group_config.offloaded_block_size
+            )
diff --git a/vllm/v1/kv_offload/base.py b/vllm/v1/kv_offload/base.py
index 3d403ea50837..37e65ed1ead2 100644
--- a/vllm/v1/kv_offload/base.py
+++ b/vllm/v1/kv_offload/base.py
@@ -202,6 +202,15 @@ def complete_store(self, keys: Collection[OffloadKey], success: bool = True):
         """
         return
 
+    def request_finished(self, req_context: ReqContext) -> None:
+        """
+        Called by the scheduler when a request has finished.
+
+        Args:
+            req_context: the context object for the finished request.
+        """
+        return
+
     def take_events(self) -> Iterable[OffloadingEvent]:
         """
         Take the offloading events from the manager.
diff --git a/vllm/v1/kv_offload/reuse_manager.py b/vllm/v1/kv_offload/reuse_manager.py
index 6cb0a5f7591c..f3a1133488f0 100644
--- a/vllm/v1/kv_offload/reuse_manager.py
+++ b/vllm/v1/kv_offload/reuse_manager.py
@@ -116,5 +116,8 @@ def complete_store(
     ) -> None:
         return self._backing.complete_store(keys, success)
 
+    def request_finished(self, req_context: ReqContext) -> None:
+        return self._backing.request_finished(req_context)
+
     def take_events(self) -> Iterable[OffloadingEvent]:
         return self._backing.take_events()
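As a usage sketch of the new hook: an alternative `OffloadPolicy` can be dropped into `OffloadingConnectorScheduler` in place of the `StoreOnComputePolicy` default set in `__init__`. The `BatchedStorePolicy` name and its `min_batch_blocks` parameter below are hypothetical, not part of this diff, and the sketch assumes full-attention groups only; it omits the block_id == 0 sliding-window / SSM padding filter that `StoreOnComputePolicy` applies.

```python
# Hypothetical sketch (not in this PR): a policy that batches stores, only
# submitting keys once `min_batch_blocks` new blocks are ready per KV cache
# group, to amortize per-job overhead.
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.policy import (
    OffloadPolicy,
)
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.state import (
    RequestKVState,
    SchedulerOffloadConfig,
)
from vllm.v1.kv_offload.base import OffloadKey


class BatchedStorePolicy(OffloadPolicy):
    def __init__(self, config: SchedulerOffloadConfig, min_batch_blocks: int = 4):
        self._config = config
        self._min_batch_blocks = min_batch_blocks
        # req_id -> per-group next stored block index
        self._stored_idx: dict[str, list[int]] = {}

    def get_blocks_to_store(
        self,
        req_kv_state: RequestKVState,
        num_offloadable_tokens: int,
    ) -> tuple[list[OffloadKey], list[int]]:
        req_id = req_kv_state.req.request_id
        stored = self._stored_idx.setdefault(
            req_id, [0] * len(self._config.kv_group_configs)
        )
        new_offload_keys: list[OffloadKey] = []
        per_group_start: list[int] = []
        for group_idx, group_config in enumerate(self._config.kv_group_configs):
            group_state = req_kv_state.group_states[group_idx]
            num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
            start_block_idx = stored[group_idx]
            per_group_start.append(start_block_idx)
            # Defer the store until enough new blocks have accumulated.
            if num_blocks - start_block_idx < self._min_batch_blocks:
                continue
            new_offload_keys.extend(
                group_state.offload_keys[start_block_idx:num_blocks]
            )
            stored[group_idx] = num_blocks
        return new_offload_keys, per_group_start

    def on_blocks_loaded(self, req_id: str, num_offloadable_tokens: int) -> None:
        # Advance past loaded blocks so they are never re-stored.
        stored = self._stored_idx.setdefault(
            req_id, [0] * len(self._config.kv_group_configs)
        )
        for group_idx, group_config in enumerate(self._config.kv_group_configs):
            num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
            stored[group_idx] = max(stored[group_idx], num_blocks)

    def request_finished(self, req_id: str) -> None:
        self._stored_idx.pop(req_id, None)
```

Wiring it in would be the one-line assignment `self.policy = BatchedStorePolicy(self.config)` in `OffloadingConnectorScheduler.__init__`. A real implementation would also need to flush blocks still below the threshold before `request_finished` drops the per-request state; otherwise a short request's tail blocks would never be offloaded.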