Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
066eb9f
Squashed merge PR #23624
ivanium Dec 2, 2025
ef59419
feat: for sliding window attention, only allocate tokens within the w…
ivanium Dec 2, 2025
0e8c625
fix: skip outside sliding window tokens when touch and save cached bl…
ivanium Dec 4, 2025
0870043
fix: make interfaces consistent and remove debug prints
ivanium Dec 4, 2025
771f1d9
nits: remove test scripts
ivanium Dec 4, 2025
7cb3378
fix: revert `cache_block()` changes as we have already handled the nu…
ivanium Dec 4, 2025
9b1b1b6
fix: revert KVCacheManager.allocate_slots() interface changes; revisi…
ivanium Dec 5, 2025
cebae0b
revert unrelated changes
ivanium Dec 5, 2025
49d133e
revert `blocks_to_touch` changes
ivanium Dec 5, 2025
b82b147
fix: update test cases
ivanium Dec 6, 2025
337d918
doc string nits
ivanium Dec 6, 2025
9ac94fb
ignore mypy errors
ivanium Dec 6, 2025
8835ccf
fix: resolve comments; mainly merge local_computed_tokens and externa…
ivanium Dec 13, 2025
400d807
fix: simplify return values of get_num_blocks_to_allocate
ivanium Dec 13, 2025
8367614
test: update test cases
ivanium Dec 13, 2025
21d32dc
fix: num_new_tokens can be 0 when load_kv_async is enabled
ivanium Dec 13, 2025
f889221
fix: revert changes to factory.py
ivanium Dec 14, 2025
3c90d57
nits
ivanium Dec 15, 2025
6cf1788
workaround lmcache new interfaces
ivanium Dec 15, 2025
77cf5ff
fix: avoid memory leak in remove_skipped_blocks; workaround gemma3 pr…
ivanium Dec 16, 2025
30e5673
nits: revise function name and comments
ivanium Dec 16, 2025
244b993
nits
ivanium Dec 18, 2025
c69caaa
fix: remove skipped blocks before passing them to the connector when …
ivanium Dec 19, 2025
cb05716
fix: should use total_computed_tokens for get_num_skipped_tokens()
ivanium Dec 22, 2025
867f1fd
perf: fast path for decode reqs in get_num_blocks_to_allocate()
ivanium Dec 22, 2025
c03fc57
fix: rename stale func names
ivanium Dec 22, 2025
7727c47
various minor fixes
ivanium Dec 22, 2025
4e8381e
refactor: simplify get_num_blocks_to_allocate
ivanium Dec 22, 2025
4afafba
nits
ivanium Dec 23, 2025
6673b11
nits
ivanium Dec 24, 2025
7c9f329
nits
ivanium Dec 25, 2025
b917624
chore: clean up debug code
ivanium Dec 25, 2025
d79e598
nits
ivanium Dec 25, 2025
006d49c
fix: num_required_blocks can be smaller than num_req_blocks in spec d…
ivanium Dec 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 62 additions & 8 deletions tests/v1/core/test_single_type_kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,23 @@
pytestmark = pytest.mark.cpu_test


def get_sliding_window_manager(sliding_window_spec, block_pool, enable_caching=True):
    """Build a SlidingWindowManager test fixture.

    Args:
        sliding_window_spec: The SlidingWindowSpec describing the KV cache group.
        block_pool: The BlockPool the manager allocates from.
        enable_caching: Whether prefix caching is enabled. Defaults to True.

    Returns:
        A SlidingWindowManager bound to KV cache group 0.
    """
    return SlidingWindowManager(
        sliding_window_spec,
        block_pool,
        enable_caching=enable_caching,
        kv_cache_group_id=0,
    )


def get_chunked_local_attention_manager(
    chunked_local_attention_spec, block_pool, enable_caching=True
):
    """Build a ChunkedLocalAttentionManager test fixture.

    Args:
        chunked_local_attention_spec: The spec describing the KV cache group.
        block_pool: The BlockPool the manager allocates from.
        enable_caching: Whether prefix caching is enabled. Defaults to True.

    Returns:
        A ChunkedLocalAttentionManager bound to KV cache group 0.
    """
    return ChunkedLocalAttentionManager(
        chunked_local_attention_spec,
        block_pool,
        enable_caching=enable_caching,
        kv_cache_group_id=0,
    )


Expand Down Expand Up @@ -332,11 +342,53 @@ def test_get_num_blocks_to_allocate():
]

assert (
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0)
== 20
)
assert (
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0)
== 15
)


def test_evictable_cached_blocks_not_double_allocated():
    """A cached block with ref_cnt == 0 (an eviction candidate) must be
    counted against free capacity by get_num_blocks_to_allocate, but must
    not be allocated again as a new block afterwards."""
    block_size = 2
    sliding_window_length = 2 * block_size
    sliding_window_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=sliding_window_length,
    )

    block_pool = BlockPool(
        num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
    )
    manager = get_sliding_window_manager(sliding_window_spec, block_pool)

    request_id = "req"
    evictable_block = block_pool.blocks[1]  # ref_cnt == 0, eviction candidate

    # Request needs 2 * block_size tokens; one block's worth is already
    # computed and held by the evictable cached block.
    num_blocks_to_allocate = manager.get_num_blocks_to_allocate(
        request_id=request_id,
        num_tokens=2 * block_size,
        new_computed_blocks=[evictable_block],
        total_computed_tokens=block_size,
    )
    # Free capacity check should count evictable cached blocks, but allocation
    # should only allocate the truly new block.
    assert num_blocks_to_allocate == 2

    manager.allocate_new_computed_blocks(
        request_id,
        [evictable_block],
        num_local_computed_tokens=block_size,
        num_external_computed_tokens=0,
    )
    # Only the one uncached block should actually be allocated.
    new_blocks = manager.allocate_new_blocks(request_id, num_tokens=4)
    assert len(new_blocks) == 1
    assert len(manager.req_to_blocks[request_id]) == 2


def test_chunked_local_attention_get_num_blocks_to_allocate():
Expand All @@ -359,8 +411,10 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
]

assert (
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0)
== 20
)
assert (
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0)
== 15
)
23 changes: 13 additions & 10 deletions vllm/v1/core/block_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,10 @@ def cache_full_blocks(
[] if self.enable_kv_cache_events else None
)
for i, blk in enumerate(new_full_blocks):
# Some blocks may be null blocks when enabling sparse attention like
# sliding window attention. We skip null blocks here.
if blk.is_null:
continue
assert blk.block_hash is None
block_hash = new_block_hashes[i]

Expand Down Expand Up @@ -361,23 +365,22 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
)
return True

def touch(self, blocks: Sequence[KVCacheBlock]) -> None:
    """Increase each block's reference count by 1.

    A block with ref_cnt == 0 sits in the free queue as an eviction
    candidate, so touching it also removes it from that queue. This is
    used when a block is hit by another request with the same prefix.

    Args:
        blocks: A list of blocks to touch.
    """
    for block in blocks:
        # ref_cnt=0 means this block is in the free list (i.e. eviction
        # candidate), so remove it.
        if block.ref_cnt == 0 and not block.is_null:
            self.free_block_queue.remove(block)
        block.ref_cnt += 1
        if self.metrics_collector:
            self.metrics_collector.on_block_accessed(block)

def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
"""Free a list of blocks. The blocks should be ordered by their
Expand Down
46 changes: 35 additions & 11 deletions vllm/v1/core/kv_cache_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_group.kv_cache_spec,
block_pool=self.block_pool,
enable_caching=enable_caching,
kv_cache_group_id=i,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
def get_num_blocks_to_allocate(
    self,
    request_id: str,
    num_tokens: int,
    new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
    num_encoder_tokens: int,
    total_computed_tokens: int,
) -> int:
    """
    Get the number of blocks needed to be allocated for the request.

    Args:
        request_id: The request ID.
        num_tokens: The total number of tokens that need a slot, including
            tokens that are already allocated.
        new_computed_blocks: The new computed blocks just hitting the
            prefix caching.
        num_encoder_tokens: The number of encoder tokens for allocating
            blocks for cross-attention.
        total_computed_tokens: Include both local and external tokens.

    Returns:
        The number of blocks to allocate.
    """
    num_blocks_to_allocate = 0
    for i, manager in enumerate(self.single_type_managers):
        if isinstance(manager, CrossAttentionManager):
            # For cross-attention, we issue a single static allocation
            # of blocks based on the number of encoder input tokens.
            num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
                request_id, num_encoder_tokens, [], 0
            )
        else:
            num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
                request_id,
                num_tokens,
                new_computed_blocks[i],
                total_computed_tokens,
            )
    return num_blocks_to_allocate

def allocate_new_computed_blocks(
    self,
    request_id: str,
    new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
    num_local_computed_tokens: int,
    num_external_computed_tokens: int,
) -> None:
    """
    Add the new computed blocks to the request. Optionally allocate new
    blocks for external computed tokens (if any).

    Args:
        request_id: The request ID.
        new_computed_blocks: The new computed blocks just hitting the
            prefix cache, one sequence per KV cache group.
        num_local_computed_tokens: The number of local computed tokens.
        num_external_computed_tokens: The number of external computed tokens.
    """
    # Fan the per-group computed blocks out to each single-type manager.
    for i, manager in enumerate(self.single_type_managers):
        manager.allocate_new_computed_blocks(
            request_id,
            new_computed_blocks[i],
            num_local_computed_tokens,
            num_external_computed_tokens,
        )

def allocate_new_blocks(
self, request_id: str, num_tokens: int, num_encoder_tokens: int = 0
self,
request_id: str,
num_tokens: int,
num_encoder_tokens: int = 0,
) -> tuple[list[KVCacheBlock], ...]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
Expand Down Expand Up @@ -184,17 +205,20 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
for manager in self.single_type_managers
]

def remove_skipped_blocks(
    self, request_id: str, total_computed_tokens: int
) -> None:
    """
    Remove the blocks that are no longer needed from `blocks` and replace
    the removed blocks with null_block.

    Args:
        request_id: The request ID.
        total_computed_tokens: The total number of computed tokens, including
            local computed tokens and external computed tokens.
    """
    # Each single-type manager decides which of its own blocks are skippable.
    for manager in self.single_type_managers:
        manager.remove_skipped_blocks(request_id, total_computed_tokens)

def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
"""
Expand Down
Loading