108 changes: 108 additions & 0 deletions tests/v1/core/test_prefix_caching.py
@@ -2513,3 +2513,111 @@ def test_block_lookup_cache_multi_blocks_per_key():
assert cache.pop(key1, 11) is block11
assert cache.get_one_block(key1) is None
assert cache.pop(key1, 12) is None


def test_can_fit_full_sequence_swa_cap_admits_long_prompt():
"""Hybrid full+SWA model with a pool sized at the startup minimum should
admit a prompt longer than the SWA cap, because SlidingWindowManager
recycles blocks during chunked prefill (issue #39734)."""
block_size = 16
sliding_window = 4 * block_size # 64 tokens
max_num_batched_tokens = 8 * block_size # 128 tokens
max_model_len = 64 * block_size # 1024 tokens — much larger than the SWA cap
# Startup pool sizing: full demands cdiv(max_model_len, bs) = 64 blocks,
# SWA demands cdiv(SW-1+max_batched, bs) + 1 = cdiv(191, 16) + 1 = 13.
# Pool minimum = 64 + 13 = 77; +1 for the null block.
num_blocks = 64 + 13 + 1

config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer_full"],
FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
),
KVCacheGroupSpec(
["layer_swa"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=sliding_window,
),
),
],
)

manager = KVCacheManager(
config,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
enable_caching=True,
hash_block_size=block_size,
)

# A prompt that is shorter than max_model_len but longer than SW + chunk:
# cdiv(prompt_len, bs) = 32 blocks. Without the cap, admission would
# demand 32 (full) + 32 (SWA) = 64 blocks. With the cap, SWA contributes
# only 13, so total = 32 + 13 = 45 ≤ pool size.
prompt_len = 32 * block_size
req = make_request("long", list(range(prompt_len)), block_size, sha256)

assert manager.can_fit_full_sequence(req)


def test_can_fit_full_sequence_full_attention_still_gates_oversized():
"""The cap only loosens the SWA group; a prompt that exceeds the
full-attention pool capacity must still be rejected."""
block_size = 16
sliding_window = 4 * block_size
max_num_batched_tokens = 8 * block_size
max_model_len = 64 * block_size
# Provide a tiny pool — even a small prompt should be rejected.
num_blocks = 5

config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer_full"],
FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
),
KVCacheGroupSpec(
["layer_swa"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=sliding_window,
),
),
],
)

manager = KVCacheManager(
config,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
enable_caching=True,
hash_block_size=block_size,
)

# 16 blocks of full attention demand alone exceeds the 5-block pool.
prompt_len = 16 * block_size
req = make_request("oversized", list(range(prompt_len)), block_size, sha256)

assert not manager.can_fit_full_sequence(req)
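
The pool-sizing and admission arithmetic these two tests rely on can be checked standalone. A minimal sketch, assuming the cap formula quoted in the test comments; `swa_admission_cap` is a hypothetical stand-in for `SlidingWindowSpec.max_admission_blocks_per_request`, not the actual implementation:

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(-a // b)


def swa_admission_cap(sliding_window: int, max_num_batched_tokens: int, block_size: int) -> int:
    # Peak blocks a recycling SWA manager holds for one request: one full
    # window plus one in-flight prefill chunk, plus a partial boundary block.
    return cdiv(sliding_window - 1 + max_num_batched_tokens, block_size) + 1


block_size = 16
cap = swa_admission_cap(4 * block_size, 8 * block_size, block_size)
assert cap == 13  # cdiv(191, 16) + 1

# Long-prompt test: 512 tokens -> 32 blocks per group without the cap.
full_blocks = cdiv(32 * block_size, block_size)  # 32
swa_blocks = min(full_blocks, cap)               # capped to 13
assert full_blocks + swa_blocks == 45            # fits the 77-block usable pool

# Oversized test: 16 full-attention blocks alone exceed the 5-block pool.
assert cdiv(16 * block_size, block_size) > 5
```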
3 changes: 3 additions & 0 deletions tests/v1/core/test_single_type_kv_cache_manager.py
@@ -22,11 +22,13 @@


def get_sliding_window_manager(sliding_window_spec, block_pool, enable_caching=True):
# Tests don't exercise admission gating; pass a large cap that is a no-op.
return SlidingWindowManager(
sliding_window_spec,
block_pool=block_pool,
enable_caching=enable_caching,
kv_cache_group_id=0,
max_admission_blocks_per_request=10**9,
)


@@ -38,6 +40,7 @@ def get_chunked_local_attention_manager(
block_pool=block_pool,
enable_caching=enable_caching,
kv_cache_group_id=0,
max_admission_blocks_per_request=10**9,
)


13 changes: 13 additions & 0 deletions vllm/v1/core/kv_cache_coordinator.py
@@ -34,6 +34,7 @@ def __init__(
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
max_num_batched_tokens: int,
use_eagle: bool,
enable_caching: bool,
enable_kv_cache_events: bool,
@@ -59,6 +60,8 @@ def __init__(
self.single_type_managers = tuple(
get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_group.kv_cache_spec,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_model_len,
block_pool=self.block_pool,
enable_caching=enable_caching,
kv_cache_group_id=i,
@@ -265,6 +268,7 @@ def __init__(
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
max_num_batched_tokens: int,
use_eagle: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
@@ -275,6 +279,7 @@
super().__init__(
kv_cache_config,
max_model_len,
max_num_batched_tokens,
use_eagle,
False,
enable_kv_cache_events,
@@ -310,6 +315,7 @@ def __init__(
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
max_num_batched_tokens: int,
use_eagle: bool,
enable_caching: bool,
enable_kv_cache_events: bool,
@@ -321,6 +327,7 @@
super().__init__(
kv_cache_config,
max_model_len,
max_num_batched_tokens,
use_eagle,
enable_caching,
enable_kv_cache_events,
@@ -375,6 +382,7 @@ def __init__(
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
max_num_batched_tokens: int,
use_eagle: bool,
enable_caching: bool,
enable_kv_cache_events: bool,
@@ -386,6 +394,7 @@
super().__init__(
kv_cache_config,
max_model_len,
max_num_batched_tokens,
use_eagle,
enable_caching,
enable_kv_cache_events,
@@ -547,6 +556,7 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList:
def get_kv_cache_coordinator(
kv_cache_config: KVCacheConfig,
max_model_len: int,
max_num_batched_tokens: int,
use_eagle: bool,
enable_caching: bool,
enable_kv_cache_events: bool,
@@ -559,6 +569,7 @@
return KVCacheCoordinatorNoPrefixCache(
kv_cache_config,
max_model_len,
max_num_batched_tokens,
use_eagle,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
@@ -570,6 +581,7 @@
return UnitaryKVCacheCoordinator(
kv_cache_config,
max_model_len,
max_num_batched_tokens,
use_eagle,
enable_caching,
enable_kv_cache_events,
@@ -581,6 +593,7 @@
return HybridKVCacheCoordinator(
kv_cache_config,
max_model_len,
max_num_batched_tokens,
use_eagle,
enable_caching,
enable_kv_cache_events,
7 changes: 7 additions & 0 deletions vllm/v1/core/kv_cache_manager.py
@@ -109,6 +109,7 @@ def __init__(
kv_cache_config: KVCacheConfig,
max_model_len: int,
hash_block_size: int,
max_num_batched_tokens: int | None = None,
enable_caching: bool = True,
use_eagle: bool = False,
log_stats: bool = False,
@@ -118,6 +119,11 @@
metrics_collector: KVCacheMetricsCollector | None = None,
) -> None:
self.max_model_len = max_model_len
# When unset, fall back to `max_model_len` so the recycling-aware cap
# collapses to the prior (uncapped) admission behavior. The scheduler
# always supplies the real value at runtime.
if max_num_batched_tokens is None:
max_num_batched_tokens = max_model_len

self.enable_caching = enable_caching
self.use_eagle = use_eagle
@@ -131,6 +137,7 @@
self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
use_eagle=self.use_eagle,
enable_caching=self.enable_caching,
enable_kv_cache_events=enable_kv_cache_events,
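
A quick check that the `max_model_len` fallback really collapses to uncapped admission: with `max_num_batched_tokens == max_model_len`, the cap (same hypothetical formula as the sketch above) can never undercut a request's real block demand.

```python
def cdiv(a: int, b: int) -> int:
    return -(-a // b)

# Numbers from the first test above; the cap formula is the same
# hypothetical stand-in for SlidingWindowSpec.max_admission_blocks_per_request.
block_size, sliding_window, max_model_len = 16, 64, 1024
fallback_cap = cdiv(sliding_window - 1 + max_model_len, block_size) + 1  # 69
largest_request = cdiv(max_model_len, block_size)                       # 64
assert largest_request <= fallback_cap  # min(required, cap) never binds
```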
1 change: 1 addition & 0 deletions vllm/v1/core/sched/scheduler.py
@@ -225,6 +225,7 @@ def __init__(
self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens,
enable_caching=self.cache_config.enable_prefix_caching,
use_eagle=self.use_eagle,
log_stats=self.log_stats,
36 changes: 35 additions & 1 deletion vllm/v1/core/single_type_kv_cache_manager.py
@@ -40,13 +40,20 @@
kv_cache_group_id: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
max_admission_blocks_per_request: int | None = None,
) -> None:
"""
Initializes the SingleTypeKVCacheManager.
Args:
kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool.
kv_cache_group_id: The id of the kv cache group of this manager.
max_admission_blocks_per_request: Recycling-aware per-request
block cap used by `get_num_blocks_to_allocate`. Only set for
spec types that recycle blocks across chunks (SWA,
chunked-local); `None` (the default) means no cap, which is
correct for full-attention-style specs that hold every
block until the request finishes.
"""
self.block_size = kv_cache_spec.block_size
self.dcp_world_size = dcp_world_size
@@ -56,6 +63,7 @@
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
self.enable_caching = enable_caching
self._max_admission_blocks_per_request = max_admission_blocks_per_request
self.new_block_ids: list[int] = []

# Mapping from request ID to blocks to track the blocks allocated
@@ -104,6 +112,19 @@ def get_num_blocks_to_allocate(
"""

num_required_blocks = cdiv(num_tokens, self.block_size)
if self._max_admission_blocks_per_request is not None:
# Recycling-aware specs (SWA, chunked-local) cap the per-request
# reservation here so admission matches the startup pool sizer
# (`SlidingWindowSpec.max_admission_blocks_per_request` / its
# chunked-local counterpart). `remove_skipped_blocks` runs from
# `allocate_slots` before each chunk's `get_num_blocks_to_allocate`,
# so each request's peak real-held blocks stay <= this cap; admitting
# on `sum(reservations) <= pool` therefore also guarantees
# `sum(peak_real_held) <= pool`.
# Drift between the two would re-introduce the deadlock from
# issue #39734 or, worse, mid-prefill OOM.
num_required_blocks = min(
num_required_blocks, self._max_admission_blocks_per_request
)
num_req_blocks = len(self.req_to_blocks.get(request_id, ()))

if request_id in self.num_cached_block:
@@ -1126,8 +1147,21 @@ def __init__(


def get_manager_for_kv_cache_spec(
kv_cache_spec: KVCacheSpec, **kwargs
kv_cache_spec: KVCacheSpec,
max_num_batched_tokens: int,
max_model_len: int,
**kwargs,
) -> SingleTypeKVCacheManager:
manager_class = spec_manager_map[type(kv_cache_spec)]
# SlidingWindow / ChunkedLocalAttention managers recycle blocks across
# chunks; the runtime admission cap must match the recycling-aware bound
# the startup pool sizer uses (single source of truth: the spec method).
if isinstance(kv_cache_spec, (SlidingWindowSpec, ChunkedLocalAttentionSpec)):
kwargs["max_admission_blocks_per_request"] = (
kv_cache_spec.max_admission_blocks_per_request(
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_model_len,
)
)
manager = manager_class(kv_cache_spec, **kwargs)
return manager
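
For orientation, a sketch of what the spec-side single source of truth could look like. This is hypothetical, reconstructed from the cap formula in the test comments; the real `SlidingWindowSpec.max_admission_blocks_per_request` may clamp differently.

```python
from dataclasses import dataclass


def cdiv(a: int, b: int) -> int:
    return -(-a // b)


@dataclass
class SlidingWindowSpecSketch:
    block_size: int
    sliding_window: int

    def max_admission_blocks_per_request(
        self, *, max_num_batched_tokens: int, max_model_len: int
    ) -> int:
        # A request never needs more than max_model_len tokens' worth of
        # blocks, and a recycling manager never holds more than a window
        # plus one in-flight chunk; take the tighter of the two bounds.
        window_tokens = self.sliding_window - 1 + max_num_batched_tokens
        return min(
            cdiv(window_tokens, self.block_size) + 1,
            cdiv(max_model_len, self.block_size),
        )
```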