53 changes: 53 additions & 0 deletions tests/v1/core/test_prefix_caching.py
@@ -1831,6 +1831,59 @@ def test_reset_prefix_cache():
assert all([blk.block_hash is None for blk in manager.block_pool.blocks])


def test_hybrid_swa_cap_does_not_crash_allocator():
block_size = 16
sliding_window = 64
num_tokens = 200
request_id = "r"

manager = KVCacheManager(
kv_cache_config=KVCacheConfig(
num_blocks=26, # 25 usable blocks + 1 null block.
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["full"],
FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
),
KVCacheGroupSpec(
["swa"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=sliding_window,
),
),
],
),
max_model_len=4096,
hash_block_size=block_size,
max_num_batched_tokens=32,
enable_caching=True,
)

req = make_request(request_id, [1] * num_tokens, block_size, sha256)
full_required_blocks = (num_tokens + block_size - 1) // block_size
allocated = manager.block_pool.get_new_blocks(
manager.block_pool.num_gpu_blocks - 1)
full_mgr, swa_mgr = manager.coordinator.single_type_managers

full_mgr.req_to_blocks[request_id] = allocated[:full_required_blocks]
swa_mgr.req_to_blocks[request_id] = allocated[
full_required_blocks:full_required_blocks + (full_required_blocks - 1)
]

assert manager.block_pool.get_num_free_blocks() == 0
assert manager.allocate_slots(req, num_new_tokens=num_tokens) is None


def test_prefix_cache_stats_disabled():
"""Test that prefix_cache_stats is None when log_stats is False."""
block_size = 16
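To see why the cap matters, here is a small standalone sketch (plain Python, not vLLM code) of the block math this test sets up: the full-attention group must budget all 200 tokens, while the SWA group's admission estimate is capped at the window plus one prefill chunk.

```python
def blocks_needed(num_tokens: int, block_size: int) -> int:
    # Ceiling division: tokens -> blocks.
    return (num_tokens + block_size - 1) // block_size

block_size = 16
num_tokens = 200
sliding_window = 64
max_num_batched_tokens = 32

# Full-attention group must hold every token: ceil(200 / 16) = 13 blocks.
full_blocks = blocks_needed(num_tokens, block_size)

# SWA group is capped at window + one prefill chunk for admission:
# min(200, 64 + 32) = 96 tokens -> 6 blocks instead of 13.
swa_tokens = min(num_tokens, sliding_window + max_num_batched_tokens)
swa_blocks = blocks_needed(swa_tokens, block_size)

assert full_blocks == 13 and swa_blocks == 6
```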
63 changes: 61 additions & 2 deletions vllm/v1/core/kv_cache_coordinator.py
@@ -15,6 +15,7 @@
from vllm.v1.core.single_type_kv_cache_manager import (
CrossAttentionManager,
SingleTypeKVCacheManager,
SlidingWindowManager,
get_manager_for_kv_cache_spec,
)
from vllm.v1.kv_cache_interface import (
@@ -40,10 +41,12 @@ def __init__(
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
max_num_batched_tokens: int = 0,
metrics_collector: KVCacheMetricsCollector | None = None,
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.enable_caching = enable_caching

self.block_pool = BlockPool(
@@ -68,7 +71,7 @@ def __init__(
for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups)
)

def get_num_blocks_to_allocate(
def get_num_blocks_needed_for_admission(
self,
request_id: str,
num_tokens: int,
@@ -78,7 +81,12 @@ def get_num_blocks_to_allocate(
num_tokens_main_model: int,
) -> int:
"""
Get the number of blocks needed to be allocated for the request.
Get the number of blocks needed to admit the request.

This is an admission estimate, not the exact allocator demand. For SWA
groups, the token count is capped at sliding_window plus one prefill
chunk (max_num_batched_tokens). Out-of-window blocks are freed between
chunks by remove_skipped_blocks(), so the scheduler does not need to
budget the full sequence length here.

Args:
request_id: The request ID.
@@ -104,6 +112,36 @@
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_encoder_tokens, [], 0, num_encoder_tokens
)
else:
effective_num_tokens = self._get_admission_num_tokens(
manager, num_tokens)
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id,
effective_num_tokens,
new_computed_blocks[i],
total_computed_tokens,
num_tokens_main_model,
)
return num_blocks_to_allocate

def get_num_blocks_to_allocate(
self,
request_id: str,
num_tokens: int,
new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
num_encoder_tokens: int,
total_computed_tokens: int,
num_tokens_main_model: int,
) -> int:
"""
Get the exact number of blocks needed to allocate for the request.
"""
num_blocks_to_allocate = 0
for i, manager in enumerate(self.single_type_managers):
if isinstance(manager, CrossAttentionManager):
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_encoder_tokens, [], 0, num_encoder_tokens
)
else:
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id,
@@ -175,6 +213,17 @@ def allocate_new_blocks(
for manager in self.single_type_managers
)

def _get_admission_num_tokens(
self,
manager: SingleTypeKVCacheManager,
num_tokens: int,
) -> int:
# SWA layers never need more than their window plus one prefill chunk.
if isinstance(manager, SlidingWindowManager):
return min(num_tokens,
manager.sliding_window + self.max_num_batched_tokens)
return num_tokens

def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""
Cache the blocks for the request.
@@ -270,6 +319,7 @@ def __init__(
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
max_num_batched_tokens: int = 0,
metrics_collector: KVCacheMetricsCollector | None = None,
):
super().__init__(
@@ -281,6 +331,7 @@
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
max_num_batched_tokens=max_num_batched_tokens,
metrics_collector=metrics_collector,
)
self.num_single_type_manager = len(self.single_type_managers)
@@ -316,6 +367,7 @@ def __init__(
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
max_num_batched_tokens: int = 0,
metrics_collector: KVCacheMetricsCollector | None = None,
):
super().__init__(
Expand All @@ -327,6 +379,7 @@ def __init__(
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
max_num_batched_tokens=max_num_batched_tokens,
metrics_collector=metrics_collector,
)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
@@ -381,6 +434,7 @@ def __init__(
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
max_num_batched_tokens: int = 0,
metrics_collector: KVCacheMetricsCollector | None = None,
):
super().__init__(
Expand All @@ -392,6 +446,7 @@ def __init__(
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
max_num_batched_tokens=max_num_batched_tokens,
metrics_collector=metrics_collector,
)
# hash_block_size: the block size used to compute block hashes.
@@ -547,6 +602,7 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList:
def get_kv_cache_coordinator(
kv_cache_config: KVCacheConfig,
max_model_len: int,
max_num_batched_tokens: int,
use_eagle: bool,
enable_caching: bool,
enable_kv_cache_events: bool,
@@ -564,6 +620,7 @@ def get_kv_cache_coordinator(
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
max_num_batched_tokens=max_num_batched_tokens,
metrics_collector=metrics_collector,
)
if len(kv_cache_config.kv_cache_groups) == 1:
@@ -576,6 +633,7 @@
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
max_num_batched_tokens=max_num_batched_tokens,
metrics_collector=metrics_collector,
)
return HybridKVCacheCoordinator(
@@ -587,5 +645,6 @@
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
max_num_batched_tokens=max_num_batched_tokens,
metrics_collector=metrics_collector,
)
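For reference, the admission cap itself reduces to a few lines. Below is a self-contained paraphrase of _get_admission_num_tokens above, using stand-in classes rather than the real vLLM manager types:

```python
# Stand-ins for the real manager classes; illustrative only.
class FullAttentionStandIn:
    pass

class SlidingWindowStandIn:
    def __init__(self, sliding_window: int) -> None:
        self.sliding_window = sliding_window

def admission_num_tokens(manager, num_tokens: int,
                         max_num_batched_tokens: int) -> int:
    # SWA layers keep at most `sliding_window` tokens resident; during
    # chunked prefill at most one extra chunk is in flight before
    # remove_skipped_blocks() frees the out-of-window blocks again.
    if isinstance(manager, SlidingWindowStandIn):
        return min(num_tokens,
                   manager.sliding_window + max_num_batched_tokens)
    return num_tokens

assert admission_num_tokens(SlidingWindowStandIn(64), 200, 32) == 96
assert admission_num_tokens(FullAttentionStandIn(), 200, 32) == 200
```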
21 changes: 19 additions & 2 deletions vllm/v1/core/kv_cache_manager.py
@@ -109,6 +109,7 @@ def __init__(
kv_cache_config: KVCacheConfig,
max_model_len: int,
hash_block_size: int,
max_num_batched_tokens: int = 0,
enable_caching: bool = True,
use_eagle: bool = False,
log_stats: bool = False,
@@ -118,6 +119,7 @@
metrics_collector: KVCacheMetricsCollector | None = None,
) -> None:
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens

self.enable_caching = enable_caching
self.use_eagle = use_eagle
@@ -131,6 +133,7 @@ def __init__(
self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_batched_tokens,
use_eagle=self.use_eagle,
enable_caching=self.enable_caching,
enable_kv_cache_events=enable_kv_cache_events,
@@ -374,7 +377,7 @@ def allocate_slots(
request.request_id, total_computed_tokens
)

num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
block_allocation_kwargs = dict(
request_id=request.request_id,
num_tokens=num_tokens_need_slot,
new_computed_blocks=new_computed_block_list,
@@ -384,10 +387,24 @@
num_tokens_main_model=num_tokens_main_model,
)

if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
num_blocks_to_allocate = (
self.coordinator.get_num_blocks_needed_for_admission(
**block_allocation_kwargs))

num_free_blocks = self.block_pool.get_num_free_blocks()
if num_blocks_to_allocate > num_free_blocks:
# Cannot allocate new blocks
return None

# The SWA-capped admission estimate can undershoot the actual allocator
# demand when a request already owns some blocks. Re-check with the
# uncapped token count so we fail cleanly here instead of raising from
# BlockPool.get_new_blocks().
actual_num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
**block_allocation_kwargs)
if actual_num_blocks_to_allocate > num_free_blocks:
return None

if (
new_computed_block_list is not self.empty_kv_cache_blocks.blocks
or num_external_computed_tokens > 0
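The net effect in allocate_slots is a two-phase guard. A minimal sketch, with hypothetical names and numbers standing in for the real call sites:

```python
def should_admit(admission_blocks: int, exact_blocks: int,
                 free_blocks: int) -> bool:
    # Phase 1: the SWA-capped admission estimate gates scheduling.
    if admission_blocks > free_blocks:
        return False
    # Phase 2: the capped estimate can undershoot the true demand when the
    # request already owns blocks, so re-check the uncapped count before
    # asking the pool for new blocks.
    return exact_blocks <= free_blocks

# With no free blocks left (as in the regression test above), admission
# fails cleanly instead of raising inside BlockPool.get_new_blocks().
assert should_admit(admission_blocks=1, exact_blocks=1, free_blocks=0) is False
```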
1 change: 1 addition & 0 deletions vllm/v1/core/sched/scheduler.py
@@ -225,6 +225,7 @@ def __init__(
self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens,
enable_caching=self.cache_config.enable_prefix_caching,
use_eagle=self.use_eagle,
log_stats=self.log_stats,