Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import copy
from collections.abc import Callable
from math import lcm
from types import SimpleNamespace

import pytest
import torch
Expand Down Expand Up @@ -33,6 +34,7 @@
init_none_hash,
make_block_hash_with_group_id,
)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
Expand Down Expand Up @@ -1023,6 +1025,66 @@ def test_prefill_hybrid_model_mamba_align():
manager.free(req0)


def test_hybrid_cache_mamba_align_shared_prefix_detection():
    """Check the uncached-shared-prefix hint in mamba-align cache mode.

    After each prefix-cache lookup the coordinator's
    ``num_uncached_common_prefix_tokens`` attribute is read. The test expects
    it to be zero on a cold cache and on a full 3-block hit, and to equal two
    blocks when a request shares only the first two blocks of a previously
    seen prefix. It then verifies that
    ``Scheduler._mamba_block_aligned_split`` trims the number of scheduled
    tokens down to that shared-prefix length.
    """
    block_size = 16
    num_prefix_blocks = 3
    prefix_len = num_prefix_blocks * block_size

    manager = make_kv_cache_manager(
        _make_hybrid_kv_cache_config(block_size, 30, ["full", "mamba_align"]),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )
    hash_fn = sha256

    def lookup(request):
        # Run the cache lookup first; the hint is read from the coordinator
        # only afterwards, so it reflects this request's lookup.
        blocks, cached_tokens = manager.get_computed_blocks(request)
        hint = manager.coordinator.num_uncached_common_prefix_tokens
        return blocks, cached_tokens, hint

    # Token ids 0, 1, 2 -- each repeated for one whole block.
    prefix = []
    for block_idx in range(num_prefix_blocks):
        prefix.extend([block_idx] * block_size)

    # Request 0: exactly the 3-block prefix, looked up against an empty cache.
    req_0 = make_request("0", prefix, block_size, hash_fn)
    computed_blocks, num_computed, num_uncached_common = lookup(req_0)
    assert num_computed == 0  # nothing cached yet
    assert num_uncached_common == 0
    manager.allocate_slots(req_0, prefix_len, 0, computed_blocks)

    # Request 1: the same 3 blocks followed by 7 tokens of its own.
    req_1 = make_request("1", prefix + [100] * 7, block_size, hash_fn)
    computed_blocks, num_computed, num_uncached_common = lookup(req_1)
    assert num_computed == prefix_len  # expect a full 3-block cache hit
    assert num_uncached_common == 0
    manager.allocate_slots(req_1, 7, prefix_len, computed_blocks)

    # Request 2: 3 blocks again, but the final token of the 3rd block differs,
    # so only the first 2 blocks are shared with the earlier requests.
    req_2 = make_request("2", prefix[:-1] + [101], block_size, hash_fn)
    computed_blocks, num_computed, num_uncached_common = lookup(req_2)
    assert num_computed == 0  # mamba_align doesn't cache intermediate blocks
    assert num_uncached_common == 2 * block_size  # shared prefix is detected

    # Scheduler side: exercise the split logic with a non-zero hint. A bare
    # namespace carrying only the attributes set here stands in for the
    # scheduler instance.
    scheduler_stub = SimpleNamespace(
        cache_config=SimpleNamespace(block_size=block_size),
        use_eagle=False,
    )
    num_new_tokens_adjusted = Scheduler._mamba_block_aligned_split(
        scheduler_stub,
        request=req_2,
        num_new_tokens=prefix_len,
        num_uncached_common_prefix_tokens=num_uncached_common,
    )
    # The 3-block request is cut back to the 2-block common prefix.
    assert num_new_tokens_adjusted == 2 * block_size

    manager.allocate_slots(req_2, prefix_len, 0, computed_blocks)

    # Cleanup: release every request's blocks, in creation order.
    for finished in (req_0, req_1, req_2):
        manager.free(finished)


def test_hybrid_model_mamba_align_with_dynamic_draft_tokens():
"""Regression test for https://github.com/vllm-project/vllm/issues/39271.

Expand Down
6 changes: 6 additions & 0 deletions vllm/v1/core/kv_cache_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,7 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList:

num_groups = len(self.kv_cache_config.kv_cache_groups)
hit_length = max_cache_hit_length
longest_hit_length = 0
hit_blocks_by_group: list[list[KVCacheBlock] | None] = [None] * num_groups

# Simple hybrid (1 full attn + 1 other): one iteration suffices.
Expand Down Expand Up @@ -667,6 +668,8 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList:
for group_id, blocks in zip(group_ids, hit_blocks):
hit_blocks_by_group[group_id] = blocks

longest_hit_length = max(longest_hit_length, curr_hit_length)
Comment thread
s3woz marked this conversation as resolved.

if curr_hit_length >= hit_length:
break
hit_length = curr_hit_length
Expand All @@ -681,6 +684,9 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList:
if (blks := hit_blocks_by_group[group_id]) is not None:
del blks[num_blocks:]

# Uncached shared-prefix detection: if some attn. group had a cache hit
# longer than the final hit length, the excess tokens are treated as a
# common prefix across requests that is still uncached:
self.num_uncached_common_prefix_tokens = longest_hit_length - hit_length
return tuple(
blocks if blocks is not None else [] for blocks in hit_blocks_by_group
), hit_length
Expand Down
21 changes: 21 additions & 0 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def _mamba_block_aligned_split(
num_new_tokens: int,
num_new_local_computed_tokens: int = 0,
num_external_computed_tokens: int = 0,
num_uncached_common_prefix_tokens: int = 0,
) -> int:
num_computed_tokens = (
request.num_computed_tokens
Expand Down Expand Up @@ -335,6 +336,16 @@ def _mamba_block_aligned_split(
else:
# prefill the last few tokens
pass

# Marconi cache admission optimization:
# cache common prefixes by scheduling num_new_tokens = common prefix length
if (
num_uncached_common_prefix_tokens >= block_size
Comment thread
tdoublep marked this conversation as resolved.
and num_new_tokens > num_uncached_common_prefix_tokens
):
num_new_tokens = num_uncached_common_prefix_tokens
# keep alignment to block_size
num_new_tokens = num_new_tokens // block_size * block_size
return num_new_tokens

def schedule(self) -> SchedulerOutput:
Expand Down Expand Up @@ -604,6 +615,7 @@ def schedule(self) -> SchedulerOutput:
num_external_computed_tokens = 0
load_kv_async = False
connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0
num_uncached_common_prefix_tokens = 0

# Get already-cached tokens.
if request.num_computed_tokens == 0:
Expand All @@ -612,6 +624,14 @@ def schedule(self) -> SchedulerOutput:
self.kv_cache_manager.get_computed_blocks(request)
)

# In case of hybrid models, obtain hint for Marconi-style APC logic
if self.has_mamba_layers:
num_uncached_common_prefix_tokens = getattr(
self.kv_cache_manager.coordinator,
"num_uncached_common_prefix_tokens",
0,
)

# Get externally-cached tokens if using a KVConnector.
if self.connector is not None:
ext_tokens, load_kv_async = (
Expand Down Expand Up @@ -724,6 +744,7 @@ def schedule(self) -> SchedulerOutput:
num_new_tokens,
num_new_local_computed_tokens,
num_external_computed_tokens,
num_uncached_common_prefix_tokens,
Comment thread
depthfirst-app[bot] marked this conversation as resolved.
)
if num_new_tokens == 0:
break
Expand Down
Loading