Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@
VLLM_USE_FBGEMM: bool = False
VLLM_GC_DEBUG: str = ""
VLLM_DEBUG_WORKSPACE: bool = False
VLLM_PIN_SWA_TOKENS: bool = False
VLLM_PIN_MIN_DROP_SIZE: int = 16
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD: int = 1024
Expand Down Expand Up @@ -1858,6 +1860,17 @@ def _resolve_rust_frontend_path() -> str | None:
# Debug workspace allocations: enables logging of workspace resize
# operations.
"VLLM_DEBUG_WORKSPACE": lambda: bool(int(os.getenv("VLLM_DEBUG_WORKSPACE", "0"))),
# On/off switch for sliding-window KV block pinning. Enabled when set to
# "1" or "true" (case-insensitive); any other value leaves it disabled.
# When enabled, each sufficiently large SWA drop (see
# VLLM_PIN_MIN_DROP_SIZE) pins the current sliding window of KV blocks --
# the freshest cached blocks and the contiguous anchor a future request
# needs to hit the SWA prefix cache -- so they are evicted last. Disabled
# by default; all out-of-window blocks then free normally.
"VLLM_PIN_SWA_TOKENS": lambda: os.getenv("VLLM_PIN_SWA_TOKENS", "0").lower()
in ("1", "true"),
# Minimum drop size (in blocks) required to activate pinning.
# Decode-step drops (usually 1 block) skip pinning to avoid bloating the
# pinned set with unique-tail hashes that provide no prefix-match value.
"VLLM_PIN_MIN_DROP_SIZE": lambda: int(os.getenv("VLLM_PIN_MIN_DROP_SIZE", "16")),
# Disables parallel execution of shared_experts via separate cuda stream
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0"))
Expand Down
50 changes: 45 additions & 5 deletions vllm/v1/core/block_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Iterable, Sequence
from typing import Any

import vllm.envs as envs
from vllm.distributed.kv_events import (
MEDIUM_GPU,
AllBlocksCleared,
Expand Down Expand Up @@ -181,6 +182,10 @@ def __init__(

self.metrics_collector = metrics_collector

# Queue of pinned blocks that have been freed (ref_cnt == 0). It is only
# populated when VLLM_PIN_SWA_TOKENS is enabled; such blocks wait here
# instead of in free_block_queue.
self.pinned_block_queue: FreeKVCacheBlockQueue = FreeKVCacheBlockQueue([])

def get_cached_block(
self, block_hash: BlockHash, kv_cache_group_ids: list[int]
) -> list[KVCacheBlock] | None:
Expand Down Expand Up @@ -352,12 +357,14 @@ def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
self._maybe_evict_cached_block(block)
assert block.ref_cnt == 0
block.ref_cnt += 1
block.is_pinned = False
if self.metrics_collector:
self.metrics_collector.on_block_allocated(block)
else:
for block in ret:
assert block.ref_cnt == 0
block.ref_cnt += 1
block.is_pinned = False
if self.metrics_collector:
self.metrics_collector.on_block_allocated(block)
return ret
Expand Down Expand Up @@ -411,11 +418,35 @@ def touch(self, blocks: Sequence[KVCacheBlock]) -> None:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.remove(block)
# A block can only be pinned when VLLM_PIN_SWA_TOKENS is enabled. A
# pinned block with ref_cnt == 0 sits in pinned_block_queue rather
# than free_block_queue, so remove it from the queue it is in.
if block.is_pinned:
self.pinned_block_queue.remove(block)
else:
self.free_block_queue.remove(block)
block.ref_cnt += 1
if self.metrics_collector:
self.metrics_collector.on_block_accessed(block)

def demote_n(self, n: int) -> int:
    """Move up to ``n`` pinned blocks back into the regular free queue.

    Only relevant when VLLM_PIN_SWA_TOKENS is enabled. It is meant for the
    case where free blocks are needed while blocks are parked in the pinned
    queue: blocks are popped from the front of ``pinned_block_queue``, their
    ``is_pinned`` flag is cleared, and they are appended to
    ``free_block_queue``. Block hashes are not touched here, so a demoted
    block stays in the prefix cache until it is physically reused.

    Args:
        n: Maximum number of pinned blocks to demote.

    Returns:
        The number of blocks actually demoted; 0 when ``n`` is not positive
        or the pinned queue has nothing to give.
    """
    if n > 0:
        # Never ask the pinned queue for more blocks than it holds.
        budget = min(n, self.pinned_block_queue.num_free_blocks)
        if budget > 0:
            demoted = self.pinned_block_queue.popleft_n(budget)
            for blk in demoted:
                # Clear the flag so later frees route to the regular queue.
                blk.is_pinned = False
            self.free_block_queue.append_n(demoted)
            return budget
    return 0

def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
"""Free a list of blocks. The blocks should be ordered by their
eviction priority, where the first block will be evicted first.
Expand All @@ -424,13 +455,22 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
ordered_blocks: A list of blocks to free ordered by their eviction
priority.
"""
# Materialize the iterable to allow multiple passes.
blocks_list = list(ordered_blocks)
for block in blocks_list:
block.ref_cnt -= 1
self.free_block_queue.append_n(
[block for block in blocks_list if block.ref_cnt == 0 and not block.is_null]
)

freed = [b for b in blocks_list if b.ref_cnt == 0 and not b.is_null]
if not envs.VLLM_PIN_SWA_TOKENS:
self.free_block_queue.append_n(freed)
else:
# Pinning enabled: route freed blocks to the regular vs pinned tier by
# is_pinned (pinned blocks are released later via demote_n).
regular_free = [b for b in freed if not b.is_pinned]
pinned_free = [b for b in freed if b.is_pinned]
if regular_free:
self.free_block_queue.append_n(regular_free)
if pinned_free:
self.pinned_block_queue.append_n(pinned_free)

def evict_blocks(self, block_ids: set[int]) -> None:
"""evict blocks from the prefix cache by their block IDs.
Expand Down
57 changes: 54 additions & 3 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass
from typing import Literal, overload

import vllm.envs as envs
from vllm.distributed.kv_events import BlockStored, KVCacheEvent
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
Expand Down Expand Up @@ -161,6 +162,28 @@ def __init__(
for group in kv_cache_config.kv_cache_groups
)

# Surface a startup hint describing what prefix-cache pinning does,
# so operators can confirm the feature is active and how to tune it.
if envs.VLLM_PIN_SWA_TOKENS:
# Reuse the per-group (kind, sliding_window) metadata computed above
# to report the SWA window size(s) without re-walking the groups.
swa_windows = sorted({sw for _, sw in self.kv_cache_event_metadata if sw})
swa_window_str = (
"/".join(str(w) for w in swa_windows) if swa_windows else "none"
)
logger.info(
"Sliding Window KV block pinning (VLLM_PIN_SWA_TOKENS) is now "
"ENABLED. SWA layers will PIN the last sliding window block "
"(%s tokens) at a higher priority than the rest of the blocks "
"in the chunk, so all the last windows are evicted last. This "
"frees up KV cache pools by keeping only the KV blocks "
"holding the last window, but are able to reuse the entire "
"chunk through sliding window attention. The reuse rate "
"increases for some traffic by holding more active KV blocks "
"on the server. Set VLLM_PIN_SWA_TOKENS=false to disable.",
swa_window_str,
)

# Pre-constructed KVCacheBlocks with no blocks, callers should use this
# via create_kv_cache_blocks instead of creating new ones to avoid GC
# overhead.
Expand Down Expand Up @@ -356,7 +379,17 @@ def allocate_slots(
num_tokens_main_model=full_num_tokens,
apply_admission_cap=True,
)
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
num_free_blocks = self.block_pool.get_num_free_blocks()
if envs.VLLM_PIN_SWA_TOKENS and num_blocks_to_allocate > num_free_blocks:
# Under pressure: demote oldest pinned blocks to make room.
# The full_sequence_must_fit admission gate otherwise
# rejects without giving the pinned tier a chance to release,
# deadlocking once pinned blocks fill the pool. demote_n is
# best-effort, so re-check free space afterwards.
deficit = num_blocks_to_allocate - num_free_blocks
self.block_pool.demote_n(deficit)
num_free_blocks = self.block_pool.get_num_free_blocks()
if num_blocks_to_allocate > num_free_blocks:
return None

num_tokens_main_model = total_computed_tokens + num_new_tokens
Expand Down Expand Up @@ -384,8 +417,16 @@ def allocate_slots(
num_tokens_main_model=num_tokens_main_model,
)

if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
# Cannot allocate new blocks
num_free_blocks = self.block_pool.get_num_free_blocks()
if envs.VLLM_PIN_SWA_TOKENS and num_blocks_to_allocate > num_free_blocks:
# Under pressure: demote oldest pinned blocks to make room.
# Hashes survive until physically recycled, so demoted blocks
# remain prefix-cache candidates. demote_n is best-effort, so
# re-check free space afterwards.
deficit = num_blocks_to_allocate - num_free_blocks
self.block_pool.demote_n(deficit)
num_free_blocks = self.block_pool.get_num_free_blocks()
if num_blocks_to_allocate > num_free_blocks:
return None

if (
Expand Down Expand Up @@ -434,6 +475,16 @@ def free(self, request: Request) -> None:
Args:
request: The request to free the blocks.
"""
# is_pinned design: when prefix-cache pinning is enabled, mark
# all of this request's remaining (non-null) blocks as pinned so
# they land in the pinned_block_queue instead of the regular free
# queue. Full-attention blocks protect the full prefix; SWA
# window blocks protect the last-window hashes.
if envs.VLLM_PIN_SWA_TOKENS:
for mgr in self.coordinator.single_type_managers:
for b in mgr.req_to_blocks.get(request.request_id, ()):
if not b.is_null:
b.is_pinned = True
self.coordinator.free(request.request_id)

def remove_skipped_blocks(
Expand Down
6 changes: 6 additions & 0 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@ class KVCacheBlock:
# Whether the block is a null block that should never be cached.
is_null: bool = False

# Whether the block is pinned as a prefix-cache retention candidate.
# Pinned blocks at ref_cnt=0 live in BlockPool.pinned_block_queue
# instead of the normal free queue; they are only reissued by
# get_new_blocks after being demoted via pressure release.
is_pinned: bool = False

@property
def block_hash(self) -> BlockHashWithGroupId | None:
return self._block_hash
Expand Down
60 changes: 60 additions & 0 deletions vllm/v1/core/single_type_kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import defaultdict
from collections.abc import Sequence

import vllm.envs as envs
from vllm.utils.math_utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (
Expand Down Expand Up @@ -649,6 +650,12 @@ def find_longest_cache_hit(
def _cache_block_mask(
self, num_cached_blocks: int, num_full_blocks: int, alignment_tokens: int
) -> list[bool] | None:
# When SWA-pinning is enabled (PR #40676), the pinned older blocks need
# to be in the prefix-cache hash map so future requests can hit them.
# The default mask skips them, defeating the pinning. Cache everything
# so the SWA-pin path can do its job.
if envs.VLLM_PIN_SWA_TOKENS:
return None
assert alignment_tokens > self.block_size
per_segment = alignment_tokens // self.block_size
tail = cdiv(self.sliding_window - 1, self.block_size)
Expand Down Expand Up @@ -687,6 +694,59 @@ def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
"""
return max(0, num_computed_tokens - self.sliding_window + 1)

def remove_skipped_blocks(
    self, request_id: str, total_computed_tokens: int
) -> None:
    """Release out-of-window blocks, optionally pinning the live window.

    With VLLM_PIN_SWA_TOKENS disabled this simply defers to the parent
    implementation. With it enabled:

    1. The blocks that newly fell out of the sliding window are located
       (walking backwards from the window start until a null block).
    2. If fewer than VLLM_PIN_MIN_DROP_SIZE blocks would be dropped, the
       parent implementation handles the release and nothing is pinned.
    3. Otherwise the blocks of the CURRENT window are flagged
       ``is_pinned``. They stay live for now; the flag only matters once
       they are freed later, when ``free_blocks`` routes them to the
       pinned tier.
    4. All out-of-window blocks are replaced by the null block and handed
       to ``block_pool.free_blocks``. Any of them that were flagged in an
       earlier call keep their ``is_pinned`` flag.

    Args:
        request_id: The request whose block list is trimmed.
        total_computed_tokens: Number of tokens computed so far, used to
            derive how many leading tokens are outside the window.
    """
    if not envs.VLLM_PIN_SWA_TOKENS:
        return super().remove_skipped_blocks(request_id, total_computed_tokens)

    skipped_tokens = self.get_num_skipped_tokens(total_computed_tokens)
    if skipped_tokens <= 0:
        return
    req_blocks = self.req_to_blocks[request_id]
    # Index of the first block that is still (at least partly) in-window.
    window_start = min(skipped_tokens // self.block_size, len(req_blocks))

    # Indices of blocks dropped by this call, newest first. Everything at
    # or below the first null block was already released earlier.
    drop_indices: list[int] = []
    for idx in reversed(range(window_start)):
        if req_blocks[idx] == self._null_block:
            break
        drop_indices.append(idx)

    # Small decode-step drops carry unique-tail hashes (no reuse value),
    # so they take the plain release path without pinning anything.
    if len(drop_indices) < envs.VLLM_PIN_MIN_DROP_SIZE:
        return super().remove_skipped_blocks(request_id, total_computed_tokens)

    # Flag the current sliding window (the freshest cached anchor). These
    # blocks remain in use; the flag takes effect when they are freed.
    window_len = cdiv(self.sliding_window, self.block_size)
    window_stop = min(window_start + window_len, len(req_blocks))
    for blk in req_blocks[window_start:window_stop]:
        if blk != self._null_block:
            blk.is_pinned = True

    # Swap every dropped block for the null block and free them together,
    # keeping the newest-first order used as eviction priority.
    released: list[KVCacheBlock] = []
    for idx in drop_indices:
        released.append(req_blocks[idx])
        req_blocks[idx] = self._null_block
    if released:
        self.block_pool.free_blocks(released)

def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
Expand Down
Loading