diff --git a/vllm/envs.py b/vllm/envs.py index c12e3cae247f..2e6cfc085223 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -254,6 +254,8 @@ VLLM_USE_FBGEMM: bool = False VLLM_GC_DEBUG: str = "" VLLM_DEBUG_WORKSPACE: bool = False + VLLM_PIN_SWA_TOKENS: bool = False + VLLM_PIN_MIN_DROP_SIZE: int = 16 VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD: int = 1024 @@ -1858,6 +1860,17 @@ def _resolve_rust_frontend_path() -> str | None: # Debug workspace allocations. # logging of workspace resize operations. "VLLM_DEBUG_WORKSPACE": lambda: bool(int(os.getenv("VLLM_DEBUG_WORKSPACE", "0"))), + # On/off switch (true/false) for sliding-window KV block pinning. When + # enabled, each SWA drop pins the current sliding window of KV blocks -- + # the freshest cached blocks and the contiguous anchor a future request + # needs to hit the SWA prefix cache -- so they are evicted last. Disabled + # by default; all out-of-window blocks then free normally. + "VLLM_PIN_SWA_TOKENS": lambda: os.getenv("VLLM_PIN_SWA_TOKENS", "0").lower() + in ("1", "true"), + # Minimum drop size (in blocks) required to activate pinning. + # Decode-step drops (usually 1 block) skip pinning to avoid bloating the + # pinned set with unique-tail hashes that provide no prefix-match value. + "VLLM_PIN_MIN_DROP_SIZE": lambda: int(os.getenv("VLLM_PIN_MIN_DROP_SIZE", "16")), # Disables parallel execution of shared_experts via separate cuda stream "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool( int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0")) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 513e4bf380b9..df0078ccdfc8 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -3,6 +3,7 @@ from collections.abc import Iterable, Sequence from typing import Any +import vllm.envs as envs from vllm.distributed.kv_events import ( MEDIUM_GPU, AllBlocksCleared, @@ -181,6 +182,10 @@ def __init__( self.metrics_collector = metrics_collector + # For Sliding Window block when VLLM_PIN_SWA_TOKENS enabled. + # The queue where the pinned blocks are stored + self.pinned_block_queue: FreeKVCacheBlockQueue = FreeKVCacheBlockQueue([]) + def get_cached_block( self, block_hash: BlockHash, kv_cache_group_ids: list[int] ) -> list[KVCacheBlock] | None: @@ -352,12 +357,14 @@ def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: self._maybe_evict_cached_block(block) assert block.ref_cnt == 0 block.ref_cnt += 1 + block.is_pinned = False if self.metrics_collector: self.metrics_collector.on_block_allocated(block) else: for block in ret: assert block.ref_cnt == 0 block.ref_cnt += 1 + block.is_pinned = False if self.metrics_collector: self.metrics_collector.on_block_allocated(block) return ret @@ -411,11 +418,35 @@ def touch(self, blocks: Sequence[KVCacheBlock]) -> None: # ref_cnt=0 means this block is in the free list (i.e. eviction # candidate), so remove it. if block.ref_cnt == 0 and not block.is_null: - self.free_block_queue.remove(block) + # block is only pinned when it belongs to Sliding Window + # attention with VLLM_PIN_SWA_TOKENS enabled. + if block.is_pinned: + self.pinned_block_queue.remove(block) + else: + self.free_block_queue.remove(block) block.ref_cnt += 1 if self.metrics_collector: self.metrics_collector.on_block_accessed(block) + def demote_n(self, n: int) -> int: + """Only used for Sliding Window attention with VLLM_PIN_SWA_TOKENS enabled. + + When free blocks are needed but blocks are pinned, move up to n oldest + pinned blocks to the normal free queue by flipping is_pinned=False. + Their hashes survive until _maybe_evict_cached_block fires on physical + reuse. Returns the number actually demoted. + """ + if n <= 0: + return 0 + num_to_demote = min(n, self.pinned_block_queue.num_free_blocks) + if num_to_demote <= 0: + return 0 + blocks = self.pinned_block_queue.popleft_n(num_to_demote) + for block in blocks: + block.is_pinned = False + self.free_block_queue.append_n(blocks) + return num_to_demote + def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: """Free a list of blocks. The blocks should be ordered by their eviction priority, where the first block will be evicted first. @@ -424,13 +455,22 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: ordered_blocks: A list of blocks to free ordered by their eviction priority. """ - # Materialize the iterable to allow multiple passes. blocks_list = list(ordered_blocks) for block in blocks_list: block.ref_cnt -= 1 - self.free_block_queue.append_n( - [block for block in blocks_list if block.ref_cnt == 0 and not block.is_null] - ) + + freed = [b for b in blocks_list if b.ref_cnt == 0 and not b.is_null] + if not envs.VLLM_PIN_SWA_TOKENS: + self.free_block_queue.append_n(freed) + else: + # Pinning enabled: route freed blocks to the regular vs pinned tier by + # is_pinned (pinned blocks are released later via demote_n). + regular_free = [b for b in freed if not b.is_pinned] + pinned_free = [b for b in freed if b.is_pinned] + if regular_free: + self.free_block_queue.append_n(regular_free) + if pinned_free: + self.pinned_block_queue.append_n(pinned_free) def evict_blocks(self, block_ids: set[int]) -> None: """evict blocks from the prefix cache by their block IDs. diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 9359d8843a91..a73c2bc1e965 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import Literal, overload +import vllm.envs as envs from vllm.distributed.kv_events import BlockStored, KVCacheEvent from vllm.logger import init_logger from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator @@ -161,6 +162,28 @@ def __init__( for group in kv_cache_config.kv_cache_groups ) + # Surface a startup hint describing what prefix-cache pinning does, + # so operators can confirm the feature is active and how to tune it. + if envs.VLLM_PIN_SWA_TOKENS: + # Reuse the per-group (kind, sliding_window) metadata computed above + # to report the SWA window size(s) without re-walking the groups. + swa_windows = sorted({sw for _, sw in self.kv_cache_event_metadata if sw}) + swa_window_str = ( + "/".join(str(w) for w in swa_windows) if swa_windows else "none" + ) + logger.info( + "Sliding Window KV block pinning (VLLM_PIN_SWA_TOKENS) is now " + "ENABLED. SWA layers will PIN the last sliding window block " + "(%s tokens) at a higher priority than the rest of the blocks " + "in the chunk, so all the last windows are evicted last. This " + "frees up KV cache pools by keeping only the KV blocks " + "holding the last window, but are able to reuse the entire " + "chunk through sliding window attention. The reuse rate " + "increases for some traffic by holding more active KV blocks " + "on the server. Set VLLM_PIN_SWA_TOKENS=false to disable.", + swa_window_str, + ) + # Pre-constructed KVCacheBlocks with no blocks, callers should use this # via create_kv_cache_blocks instead of creating new ones to avoid GC # overhead. @@ -356,7 +379,17 @@ def allocate_slots( num_tokens_main_model=full_num_tokens, apply_admission_cap=True, ) - if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): + num_free_blocks = self.block_pool.get_num_free_blocks() + if envs.VLLM_PIN_SWA_TOKENS and num_blocks_to_allocate > num_free_blocks: + # Under pressure: demote oldest pinned blocks to make room. + # The full_sequence_must_fit admission gate otherwise + # rejects without giving the pinned tier a chance to release, + # deadlocking once pinned blocks fill the pool. demote_n is + # best-effort, so re-check free space afterwards. + deficit = num_blocks_to_allocate - num_free_blocks + self.block_pool.demote_n(deficit) + num_free_blocks = self.block_pool.get_num_free_blocks() + if num_blocks_to_allocate > num_free_blocks: return None num_tokens_main_model = total_computed_tokens + num_new_tokens @@ -384,8 +417,16 @@ def allocate_slots( num_tokens_main_model=num_tokens_main_model, ) - if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): - # Cannot allocate new blocks + num_free_blocks = self.block_pool.get_num_free_blocks() + if envs.VLLM_PIN_SWA_TOKENS and num_blocks_to_allocate > num_free_blocks: + # Under pressure: demote oldest pinned blocks to make room. + # Hashes survive until physically recycled, so demoted blocks + # remain prefix-cache candidates. demote_n is best-effort, so + # re-check free space afterwards. + deficit = num_blocks_to_allocate - num_free_blocks + self.block_pool.demote_n(deficit) + num_free_blocks = self.block_pool.get_num_free_blocks() + if num_blocks_to_allocate > num_free_blocks: return None if ( @@ -434,6 +475,16 @@ def free(self, request: Request) -> None: Args: request: The request to free the blocks. """ + # is_pinned design: when prefix-cache pinning is enabled, mark + # all of this request's remaining (non-null) blocks as pinned so + # they land in the pinned_block_queue instead of the regular free + # queue. Full-attention blocks protect the full prefix; SWA + # window blocks protect the last-window hashes. + if envs.VLLM_PIN_SWA_TOKENS: + for mgr in self.coordinator.single_type_managers: + for b in mgr.req_to_blocks.get(request.request_id, ()): + if not b.is_null: + b.is_pinned = True self.coordinator.free(request.request_id) def remove_skipped_blocks( diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 7f3a5e4fdf3f..ac13cb7c7b1f 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -132,6 +132,12 @@ class KVCacheBlock: # Whether the block is a null block that should never be cached. is_null: bool = False + # Whether the block is pinned as a prefix-cache retention candidate. + # Pinned blocks at ref_cnt=0 live in BlockPool.pinned_block_queue + # instead of the normal free queue; they are only reissued by + # get_new_blocks after being demoted via pressure release. + is_pinned: bool = False + @property def block_hash(self) -> BlockHashWithGroupId | None: return self._block_hash diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index cd000dc849e7..6492ac458f68 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -5,6 +5,7 @@ from collections import defaultdict from collections.abc import Sequence +import vllm.envs as envs from vllm.utils.math_utils import cdiv from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import ( @@ -649,6 +650,12 @@ def find_longest_cache_hit( def _cache_block_mask( self, num_cached_blocks: int, num_full_blocks: int, alignment_tokens: int ) -> list[bool] | None: + # When SWA-pinning is enabled (PR #40676), the pinned older blocks need + # to be in the prefix-cache hash map so future requests can hit them. + # The default mask skips them, defeating the pinning. Cache everything + # so the SWA-pin path can do its job. + if envs.VLLM_PIN_SWA_TOKENS: + return None assert alignment_tokens > self.block_size per_segment = alignment_tokens // self.block_size tail = cdiv(self.sliding_window - 1, self.block_size) @@ -687,6 +694,59 @@ def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: """ return max(0, num_computed_tokens - self.sliding_window + 1) + def remove_skipped_blocks( + self, request_id: str, total_computed_tokens: int + ) -> None: + """Sliding-window block release with prefix-cache pinning. + + As the window advances, the oldest in-window blocks fall out of + window. With VLLM_PIN_SWA_TOKENS enabled, the CURRENT sliding + window blocks are PINNED -- they are the freshest cached blocks + and the exact contiguous run a future request needs to anchor an + SWA prefix-cache hit at this chunk boundary. The window blocks stay + live this step; marking is_pinned routes them to the pinned tier + when they later fall out of window. All out-of-window blocks are + freed -- any previously-pinned window among them survives via its + is_pinned flag, which free_blocks routes to the pinned tier. + Everything else falls back to the base free-all path. + """ + if not envs.VLLM_PIN_SWA_TOKENS: + return super().remove_skipped_blocks(request_id, total_computed_tokens) + + num_skipped_tokens = self.get_num_skipped_tokens(total_computed_tokens) + if num_skipped_tokens <= 0: + return + blocks = self.req_to_blocks[request_id] + num_skipped_blocks = min(num_skipped_tokens // self.block_size, len(blocks)) + + # Small decode-step drops carry unique-tail hashes (no reuse value). + num_new_drops = 0 + for j in range(num_skipped_blocks - 1, -1, -1): + if blocks[j] == self._null_block: + break + num_new_drops += 1 + if num_new_drops < envs.VLLM_PIN_MIN_DROP_SIZE: + return super().remove_skipped_blocks(request_id, total_computed_tokens) + + # Pin the current sliding window (the freshest cached anchor). The + # blocks stay live now; is_pinned takes effect when they later drop. + window_blocks = cdiv(self.sliding_window, self.block_size) + win_end = min(num_skipped_blocks + window_blocks, len(blocks)) + for i in range(num_skipped_blocks, win_end): + if blocks[i] != self._null_block: + blocks[i].is_pinned = True + + # Free ALL out-of-window blocks. Any previously-pinned window blocks + # among them are routed to the pinned tier by free_blocks. + to_free: list[KVCacheBlock] = [] + for i in range(num_skipped_blocks - 1, -1, -1): + if blocks[i] == self._null_block: + break + to_free.append(blocks[i]) + blocks[i] = self._null_block + if to_free: + self.block_pool.free_blocks(to_free) + def get_num_common_prefix_blocks(self, running_request_id: str) -> int: """ NOTE(Chen): The prefix blocks are null blocks for sliding window layers.