From ae864db9fcf6677c68af794de667317f14ae68b3 Mon Sep 17 00:00:00 2001 From: Or Ozeri Date: Wed, 29 Apr 2026 11:47:42 +0300 Subject: [PATCH] [kv_offload+HMA][12/N]: Scheduler-side support for sliding window groups This commit extends the scheduler-side OffloadingConnector to support sliding window and Mamba KV cache groups. Signed-off-by: Or Ozeri --- .../kv_connector/v1/offloading/scheduler.py | 378 +++++++++++++----- 1 file changed, 280 insertions(+), 98 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py index f437782070df..7b0493610896 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py @@ -18,6 +18,12 @@ from vllm.utils.math_utils import cdiv from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheSpec, + MambaSpec, + SlidingWindowSpec, +) from vllm.v1.kv_offload.base import ( GPULoadStoreSpec, OffloadingManager, @@ -44,8 +50,12 @@ class TransferJobStatus: # Offload keys this job covers; passed to manager.complete_*(). keys: set[OffloadKey] is_store: bool - # GPU blocks the fence tracks. Store src blocks; None for loads. - gpu_block_ids: list[int] | None = None + # Store src block IDs whose ref_cnt protects them while the request + # runs. Only registered in _block_id_to_pending_jobs on request_finished. + non_sliding_window_block_ids: list[int] | None = None + # Store src block IDs that may be freed before the request finishes. + # Registered in _block_id_to_pending_jobs at store creation time. 
+ sliding_window_block_ids: list[int] | None = None class GroupOffloadConfig(NamedTuple): @@ -53,6 +63,23 @@ class GroupOffloadConfig(NamedTuple): gpu_block_size: int offloaded_block_size: int hash_block_size_factor: int + # None below means full attention + sliding_window_size_in_blocks: int | None + + +def get_sliding_window_size_in_blocks( + kv_cache_spec: KVCacheSpec, offloaded_block_size: int +) -> int | None: + if isinstance(kv_cache_spec, SlidingWindowSpec): + assert kv_cache_spec.sliding_window > 0 + return cdiv(kv_cache_spec.sliding_window, offloaded_block_size) + + if isinstance(kv_cache_spec, MambaSpec): + # Mamba depends on a single state + return 1 + + assert isinstance(kv_cache_spec, FullAttentionSpec) + return None class SchedulerOffloadConfig(NamedTuple): @@ -73,6 +100,10 @@ def from_spec(cls, spec: OffloadingSpec) -> "SchedulerOffloadConfig": (gpu_block_size * spec.block_size_factor) // spec.hash_block_size ), + sliding_window_size_in_blocks=get_sliding_window_size_in_blocks( + spec.kv_cache_config.kv_cache_groups[idx].kv_cache_spec, + gpu_block_size * spec.block_size_factor, + ), ) for idx, gpu_block_size in enumerate(spec.gpu_block_size) ), @@ -86,6 +117,9 @@ class RequestGroupState: block_ids: list[int] = field(default_factory=list) # index of next block (of size offloaded_block_size) to offload next_stored_block_idx: int = 0 + # number of offloaded blocks hit (including GPU prefix cache) + # when the request first started + num_hit_blocks: int = 0 @dataclass(slots=True) @@ -139,6 +173,14 @@ def advance_stored_idx(self, num_offloadable_tokens: int) -> None: num_blocks = num_offloadable_tokens // group_config.offloaded_block_size group_state.next_stored_block_idx = num_blocks + def update_num_hit_blocks(self, num_cached_tokens: int) -> None: + for group_config, group_state in zip( + self.config.kv_group_configs, self.group_states + ): + group_state.num_hit_blocks = ( + num_cached_tokens // group_config.offloaded_block_size + ) + class 
OffloadingConnectorScheduler:
     """Implementation of Scheduler side methods"""

@@ -147,12 +189,25 @@ def __init__(self, spec: OffloadingSpec):
         self.config = SchedulerOffloadConfig.from_spec(spec)
         self.manager: OffloadingManager = spec.get_manager()

-        attention_groups: list[int] = []
-        for idx, _ in enumerate(spec.kv_cache_config.kv_cache_groups):
-            # currently treat all groups as full attention
-            attention_groups.append(idx)
+        full_attention_groups: list[int] = []
+        sliding_window_groups: list[int] = []
+        for group_config in self.config.kv_group_configs:
+            if group_config.sliding_window_size_in_blocks is None:
+                full_attention_groups.append(group_config.group_idx)
+            else:
+                sliding_window_groups.append(group_config.group_idx)
+
+        # sort sliding window groups by window size in decreasing order
+        def _sliding_window_sort_key(i: int) -> int:
+            val = self.config.kv_group_configs[i].sliding_window_size_in_blocks
+            assert val is not None
+            return val
+
+        sliding_window_groups.sort(key=_sliding_window_sort_key, reverse=True)

-        self.lookup_groups = attention_groups
+        # used by _lookup
+        self._sliding_window_groups: tuple[int, ...] = tuple(sliding_window_groups)
+        self._lookup_groups = tuple(full_attention_groups) + self._sliding_window_groups

         self._req_status: dict[ReqId, RequestOffloadState] = {}
         self._current_batch_load_jobs: dict[int, TransferJob] = {}
@@ -167,8 +222,11 @@ def __init__(self, spec: OffloadingSpec):
         self._job_counter: int = 0
         self._jobs: dict[int, TransferJobStatus] = {}

-        # block_id -> pending store job_ids. Populated only for finished
-        # requests (running-request blocks are protected by their ref_cnt).
+        # block_id -> pending store job_ids. Used to track jobs that need
+        # flushing in case a block is re-allocated by the KV cache manager.
+        # Populated only for finished requests (running-request blocks are
+        # protected by their ref_cnt) and for sliding window blocks (which can
+        # be freed before a request finishes). 
self._block_id_to_pending_jobs: dict[int, set[int]] = {} def _generate_job_id(self) -> int: @@ -176,10 +234,18 @@ def _generate_job_id(self) -> int: self._job_counter += 1 return job_id + def _remove_pending_job(self, job_id: int, block_ids: list[int] | None) -> None: + for bid in block_ids or (): + pending = self._block_id_to_pending_jobs[bid] + pending.remove(job_id) + if not pending: + del self._block_id_to_pending_jobs[bid] + def _maximal_prefix_lookup( self, keys: Iterable[OffloadKey], req_context: ReqContext ) -> int | None: - """Find the length of the maximal prefix of offloaded blocks.""" + """Return the number of consecutive offloaded blocks from the start, + or None if the backend deferred a lookup.""" hit_count = 0 defer_lookup = False for key in keys: @@ -200,8 +266,9 @@ def _sliding_window_lookup( sliding_window_size: int, req_context: ReqContext, ) -> int | None: - """Find the maximal ending position of consecutive offloaded blocks - within a sliding window.""" + """Return the end index (in `keys`) of the last run of + `sliding_window_size` consecutive hits, scanning from the end. 
+ Returns 0 on miss, None if the backend deferred a lookup.""" defer_lookup = False consecutive_hits = 0 for idx in range(len(keys) - 1, -1, -1): @@ -219,6 +286,160 @@ def _sliding_window_lookup( return idx + sliding_window_size if not defer_lookup else None return consecutive_hits if not defer_lookup else None + def _touch(self, req_status: RequestOffloadState): + for group_config, group_state in zip( + self.config.kv_group_configs, req_status.group_states + ): + if group_config.sliding_window_size_in_blocks is None: + self.manager.touch(group_state.offload_keys) + else: + # we aim to keep just blocks that are necessary to hit + # the original request (+ decoded blocks) + blocks_to_skip = max( + 0, + group_state.num_hit_blocks + - group_config.sliding_window_size_in_blocks, + ) + self.manager.touch(group_state.offload_keys[blocks_to_skip:]) + + def _lookup(self, req_status: RequestOffloadState) -> int | None: + """ + Find how many tokens beyond num_locally_computed_tokens can be loaded. + + Iterates full-attention groups first (prefix lookup), then sliding-window + groups (suffix lookup). Each group may tighten max_hit_size_tokens, which + can invalidate an earlier group's result, so the loop re-runs when that + happens until num_hit_tokens converges. 
+ """ + num_computed_tokens = req_status.num_locally_computed_tokens + max_hit_size_tokens: int = req_status.req.num_tokens + if self._sliding_window_groups: + # the last prompt token has to be recomputed to get the logprobs + # for sliding window attention, we must reduce by 1 to make sure + # we still have a hit after reduction + max_hit_size_tokens -= 1 + num_hit_tokens: int = 0 + defer_lookup = False + lookup_groups = self._lookup_groups + while lookup_groups: + looked_up_sliding_window: bool = False + groups_iter = iter(lookup_groups) + lookup_groups = () + for group_idx in groups_iter: + group_config: GroupOffloadConfig = self.config.kv_group_configs[ + group_idx + ] + group_state: RequestGroupState = req_status.group_states[group_idx] + offloaded_block_size = group_config.offloaded_block_size + offload_keys = group_state.offload_keys + + assert ( + len(offload_keys) + >= req_status.req.num_tokens // offloaded_block_size + ) + + # Constrain to block-aligned boundary for this group + max_hit_size_tokens = min( + max_hit_size_tokens, len(offload_keys) * offloaded_block_size + ) + if max_hit_size_tokens - num_computed_tokens < offloaded_block_size: + # we can only load less than a block, better skip + return 0 + + num_blocks = min( + cdiv(max_hit_size_tokens, offloaded_block_size), len(offload_keys) + ) + start_block_idx = num_computed_tokens // offloaded_block_size + offload_keys = offload_keys[start_block_idx:num_blocks] + sliding_window_size_in_blocks = ( + group_config.sliding_window_size_in_blocks + ) + + # end index (in the sliced offload_keys) up to which we + # have backend-confirmed hits + num_hit_blocks: int | None + if sliding_window_size_in_blocks is None: + num_hit_blocks = self._maximal_prefix_lookup( + offload_keys, req_status.req_context + ) + else: + num_hit_blocks = self._sliding_window_lookup( + offload_keys, + sliding_window_size_in_blocks, + req_status.req_context, + ) + if num_hit_blocks == 0: + return 0 + + if num_hit_blocks is None: + 
defer_lookup = True + else: + max_hit_size_tokens = min( + max_hit_size_tokens, + offloaded_block_size * (start_block_idx + num_hit_blocks), + ) + + new_num_hit_tokens = max_hit_size_tokens - num_computed_tokens + if new_num_hit_tokens < offloaded_block_size: + # we can only load less than a block, better skip + return 0 + + if new_num_hit_tokens < num_hit_tokens: + if defer_lookup: + # make another iteration on all groups to check + # if we still need to defer lookup + defer_lookup = False + lookup_groups = self._lookup_groups + elif looked_up_sliding_window and not lookup_groups: + # we need another iteration to confirm previously looked up + # sliding window works with the new_num_hit_tokens + lookup_groups = self._sliding_window_groups + + looked_up_sliding_window |= sliding_window_size_in_blocks is not None + num_hit_tokens = new_num_hit_tokens + + if defer_lookup: + logger.debug( + "Offloading manager delayed request %s as backend requested", + req_status.req.request_id, + ) + return None + + # possibly delay request if any of the hit blocks is already being loaded + if self._blocks_being_loaded: + for group_config, group_state in zip( + self.config.kv_group_configs, req_status.group_states + ): + offloaded_block_size = group_config.offloaded_block_size + sliding_window_size_in_blocks = ( + group_config.sliding_window_size_in_blocks + ) + offload_keys = group_state.offload_keys + num_blocks = cdiv( + num_computed_tokens + num_hit_tokens, offloaded_block_size + ) + start_block_idx = num_computed_tokens // offloaded_block_size + offload_keys = offload_keys[start_block_idx:num_blocks] + if sliding_window_size_in_blocks is not None: + offload_keys = offload_keys[-sliding_window_size_in_blocks:] + if any(key in self._blocks_being_loaded for key in offload_keys): + # hit blocks are being loaded, delay request + logger.debug( + "Delaying request %s since some of its" + " blocks are already being loaded", + req_status.req.request_id, + ) + return None + + 
logger.debug( + "Request %s hit %s offloaded tokens after %s GPU hit tokens", + req_status.req.request_id, + num_hit_tokens, + num_computed_tokens, + ) + + return num_hit_tokens + def get_num_new_matched_tokens( self, request: Request, num_computed_tokens: int ) -> tuple[int | None, bool]: @@ -241,96 +462,28 @@ def get_num_new_matched_tokens( - `True` if tokens will be loaded asynchronously (between scheduler steps). """ + is_new_request = False if req_status := self._req_status.get(request.request_id): # make sure block IDs are cleared for group_state in req_status.group_states: group_state.block_ids.clear() else: + is_new_request = True req_status = RequestOffloadState(config=self.config, req=request) self._req_status[request.request_id] = req_status req_status.update_offload_keys() req_status.num_locally_computed_tokens = num_computed_tokens - for gs in req_status.group_states: - self.manager.touch(gs.offload_keys) - - # Start with the full request size as the maximum loadable - max_hit_size_tokens: int = req_status.req.num_tokens - num_hit_tokens: int = 0 - defer_lookup = False - delay_request = False - for group_idx in self.lookup_groups: - group_config: GroupOffloadConfig = self.config.kv_group_configs[group_idx] - offloaded_block_size = group_config.offloaded_block_size - offload_keys = req_status.group_states[group_idx].offload_keys - - num_blocks = max_hit_size_tokens // offloaded_block_size - assert len(offload_keys) >= num_blocks - - # Constrain to block-aligned boundary for this group - max_hit_size_tokens = num_blocks * offloaded_block_size - num_hit_tokens = max_hit_size_tokens - num_computed_tokens - if num_hit_tokens < offloaded_block_size: - # we can only load less than a block, better skip - return 0, False - - start_block_idx = num_computed_tokens // offloaded_block_size - offload_keys = offload_keys[start_block_idx:num_blocks] - # Full attention relies on all previous KV cache blocks. 
- # Thus, we search for a maximal prefix of KV cache which are all cached. - block_hits = self._maximal_prefix_lookup( - offload_keys, req_status.req_context + num_hit_tokens = self._lookup(req_status) + if is_new_request: + req_status.update_num_hit_blocks( + num_computed_tokens + (num_hit_tokens or 0) ) - if block_hits == 0: - return 0, False - if block_hits is None: - defer_lookup = True - else: - # Further constrain based on what's actually available by backend - max_hit_size_tokens = offloaded_block_size * ( - start_block_idx + block_hits - ) + self._touch(req_status) - num_hit_tokens = max_hit_size_tokens - num_computed_tokens - if num_hit_tokens < offloaded_block_size: - # we can only load less than a block, better skip - return 0, False - - if ( - block_hits - and self._blocks_being_loaded - and any( - key in self._blocks_being_loaded - for key in offload_keys[:block_hits] - ) - ): - # hit blocks are being loaded, delay request - delay_request = True - - if defer_lookup: - logger.debug( - "Offloading manager delayed request %s as backend requested", - req_status.req.request_id, - ) - return None, False - - if delay_request: - logger.debug( - "Delaying request %s since some of its blocks are already being loaded", - req_status.req.request_id, - ) - return None, False - - logger.debug( - "Request %s hit %s offloaded tokens after %s GPU hit tokens", - request.request_id, - num_hit_tokens, - num_computed_tokens, - ) - - return num_hit_tokens, True + return num_hit_tokens, bool(num_hit_tokens) def update_state_after_alloc( self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int @@ -375,6 +528,13 @@ def update_state_after_alloc( ) num_pending_gpu_blocks = num_gpu_blocks - num_locally_computed_gpu_blocks + if group_config.sliding_window_size_in_blocks is not None: + assert ( + num_pending_gpu_blocks + <= group_config.sliding_window_size_in_blocks + * self.config.block_size_factor + ) + num_blocks = cdiv(num_cached_tokens, offloaded_block_size) 
assert len(offload_keys) >= num_blocks if num_pending_gpu_blocks: @@ -513,17 +673,21 @@ def _build_store_jobs( req_status.advance_stored_idx(num_offloadable_tokens) continue - for group_state in req_status.group_states: - self.manager.touch(group_state.offload_keys) + self._touch(req_status) keys_to_store = set(store_output.keys_to_store) group_sizes: list[int] = [] block_indices: list[int] = [] src_block_ids: list[int] = [] + sliding_window_block_ids: list[int] = [] + non_sliding_window_block_ids: list[int] = [] for group_config, group_state in zip( self.config.kv_group_configs, req_status.group_states ): + is_sliding_window = ( + group_config.sliding_window_size_in_blocks is not None + ) num_blocks = num_offloadable_tokens // group_config.offloaded_block_size start_block_idx = group_state.next_stored_block_idx block_ids = group_state.block_ids @@ -547,6 +711,11 @@ def _build_store_jobs( elif start_gpu_block_idx is None: start_gpu_block_idx = gpu_block_idx + i src_block_ids.append(block_id) + if is_sliding_window: + sliding_window_block_ids.append(block_id) + else: + non_sliding_window_block_ids.append(block_id) + group_sizes.append(num_group_blocks) block_indices.append(start_gpu_block_idx or 0) group_state.next_stored_block_idx = num_blocks @@ -562,12 +731,21 @@ def _build_store_jobs( any_jid = next(iter(req_status.transfer_jobs)) assert self._jobs[any_jid].is_store req_status.transfer_jobs.add(job_id) + + # Watch sliding window blocks as they may get evicted + # before the request finishes + for bid in sliding_window_block_ids or (): + self._block_id_to_pending_jobs.setdefault(bid, set()).add(job_id) + + # the non-sliding window blocks will be watched only + # when the request finishes self._jobs[job_id] = TransferJobStatus( req_id=req_id, pending_count=self.config.num_workers, keys=set(keys_to_store), is_store=True, - gpu_block_ids=src_block_ids, + non_sliding_window_block_ids=non_sliding_window_block_ids, + sliding_window_block_ids=sliding_window_block_ids or 
None, ) store_jobs[job_id] = TransferJob( @@ -632,12 +810,16 @@ def update_connector_output(self, connector_output: KVConnectorOutput): self._blocks_being_loaded.difference_update(job_status.keys) req_status = self._req_status[job_status.req_id] - if self._block_id_to_pending_jobs and req_status.req.is_finished(): - for bid in job_status.gpu_block_ids or (): - pending = self._block_id_to_pending_jobs[bid] - pending.remove(job_id) - if not pending: - del self._block_id_to_pending_jobs[bid] + if self._block_id_to_pending_jobs: + # Sliding window blocks are tracked from store creation + # and must be cleaned up unconditionally. + self._remove_pending_job(job_id, job_status.sliding_window_block_ids) + # Non-sliding-window blocks are only tracked after + # request_finished, so only clean up for finished requests. + if req_status.req.is_finished(): + self._remove_pending_job( + job_id, job_status.non_sliding_window_block_ids + ) del self._jobs[job_id] req_status.transfer_jobs.remove(job_id) @@ -671,7 +853,7 @@ def request_finished( # Register them so future block reuse triggers a flush. for job_id in req_status.transfer_jobs: job_status = self._jobs[job_id] - for bid in job_status.gpu_block_ids or (): + for bid in job_status.non_sliding_window_block_ids or (): self._block_id_to_pending_jobs.setdefault(bid, set()).add(job_id) return False, None