80 changes: 80 additions & 0 deletions tests/v1/core/test_kv_cache_utils.py
@@ -36,6 +36,7 @@
is_kv_cache_spec_uniform,
make_block_hash_with_group_id,
tensor_data,
token_capacity_kv_cache_groups,
)
from vllm.v1.kv_cache_interface import (
ChunkedLocalAttentionSpec,
@@ -2214,3 +2215,82 @@ def test_hma_not_disabled_when_kv_events_enabled():
assert vllm_config.scheduler_config.disable_hybrid_kv_cache_manager is False, (
"kv_events_config must not force-disable the hybrid KV cache manager."
)


def _vllm_config_with_mamba_mode(mamba_cache_mode: str) -> VllmConfig:
from vllm.config import CacheConfig

return VllmConfig(
model_config=ModelConfig(max_model_len=16),
cache_config=CacheConfig(mamba_cache_mode=mamba_cache_mode),
)


def _kv_cache_config(*groups: KVCacheGroupSpec) -> KVCacheConfig:
return KVCacheConfig(
num_blocks=0,
kv_cache_tensors=[],
kv_cache_groups=list(groups),
)


def test_token_capacity_groups_dense_all_attention():
"""All-attention models: every group contributes to per-token capacity."""
vllm_config = _vllm_config_with_mamba_mode("none")
attn = new_kv_cache_spec()
config = _kv_cache_config(
KVCacheGroupSpec(["a0"], attn),
KVCacheGroupSpec(["a1"], attn),
)
assert token_capacity_kv_cache_groups(vllm_config, config) == config.kv_cache_groups


@pytest.mark.parametrize("mamba_mode", ["none", "align"])
def test_token_capacity_groups_hybrid_excludes_o1_mamba(mamba_mode):
"""Hybrid with mamba_cache_mode in ('none','align'): filter drops Mamba."""
vllm_config = _vllm_config_with_mamba_mode(mamba_mode)
attn_group = KVCacheGroupSpec(["a0"], new_kv_cache_spec())
mamba_group = KVCacheGroupSpec(["m0"], new_mamba_spec())
config = _kv_cache_config(attn_group, mamba_group)
assert token_capacity_kv_cache_groups(vllm_config, config) == [attn_group]


def test_token_capacity_groups_hybrid_mamba_all_includes_mamba():
"""mamba_cache_mode='all': Mamba state scales with sequence length, kept."""
vllm_config = _vllm_config_with_mamba_mode("all")
attn_group = KVCacheGroupSpec(["a0"], new_kv_cache_spec())
mamba_group = KVCacheGroupSpec(["m0"], new_mamba_spec())
config = _kv_cache_config(attn_group, mamba_group)
assert token_capacity_kv_cache_groups(vllm_config, config) == [
attn_group,
mamba_group,
]


def test_token_capacity_groups_mamba_only_falls_back():
"""Mamba-only model with mode='none' would filter to empty; keep all."""
vllm_config = _vllm_config_with_mamba_mode("none")
mamba_groups = [
KVCacheGroupSpec(["m0"], new_mamba_spec()),
KVCacheGroupSpec(["m1"], new_mamba_spec()),
]
config = _kv_cache_config(*mamba_groups)
assert token_capacity_kv_cache_groups(vllm_config, config) == mamba_groups


def test_token_capacity_groups_multiple_attn_and_mamba():
"""Nemotron-H-style 1 attn + 3 Mamba groups: drops to 1 attn group."""
vllm_config = _vllm_config_with_mamba_mode("none")
attn_group = KVCacheGroupSpec(["a0"], new_kv_cache_spec())
mamba_groups = [
KVCacheGroupSpec([f"m{i}"], new_mamba_spec()) for i in range(3)
]
config = _kv_cache_config(attn_group, *mamba_groups)
assert token_capacity_kv_cache_groups(vllm_config, config) == [attn_group]


def test_token_capacity_groups_empty_config_returns_empty():
"""Edge case: no groups at all → empty list (no IndexError)."""
vllm_config = _vllm_config_with_mamba_mode("none")
config = _kv_cache_config()
assert token_capacity_kv_cache_groups(vllm_config, config) == []
26 changes: 26 additions & 0 deletions vllm/v1/core/kv_cache_utils.py
@@ -19,6 +19,7 @@
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.mem_utils import format_gib
from vllm.v1.kv_cache_interface import (
AttentionSpec,
ChunkedLocalAttentionSpec,
FullAttentionSpec,
KVCacheConfig,
@@ -1683,6 +1684,31 @@ def generate_scheduler_kv_cache_config(
return cfg


def token_capacity_kv_cache_groups(
vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
) -> list[KVCacheGroupSpec]:
"""KV cache groups that contribute to per-token capacity.

Attention groups always scale with sequence length. Mamba groups only
scale when ``mamba_cache_mode == 'all'``; in ``'none'`` and ``'align'``
they hold O(1) state per request and pre-reserve a fixed number of
blocks, so counting them in the per-token divisor under-reports
capacity on hybrid models.

Falls back to all groups if the filter would produce an empty list.
"""
mamba_scales = (
getattr(vllm_config.cache_config, "mamba_cache_mode", "none") == "all"
)
groups = [
g
for g in kv_cache_config.kv_cache_groups
if isinstance(g.kv_cache_spec, AttentionSpec)
or (isinstance(g.kv_cache_spec, MambaSpec) and mamba_scales)
]
return groups or list(kv_cache_config.kv_cache_groups)


def _report_kv_cache_config(
vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
) -> None:
12 changes: 5 additions & 7 deletions vllm/v1/core/sched/scheduler.py
@@ -38,6 +38,7 @@
)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
from vllm.v1.core.kv_cache_utils import token_capacity_kv_cache_groups
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import (
CachedRequestData,
@@ -277,15 +278,12 @@ def __init__(
if isinstance(group.kv_cache_spec, AttentionSpec):
self.routed_experts_attn_gid = gid
break
-            min_block_size = min(
-                [
-                    group.kv_cache_spec.block_size
-                    for group in kv_cache_config.kv_cache_groups
-                ]
+            capacity_groups = token_capacity_kv_cache_groups(
+                self.vllm_config, kv_cache_config
            )
-            num_groups = len(kv_cache_config.kv_cache_groups)
+            min_block_size = min(g.kv_cache_spec.block_size for g in capacity_groups)
            self.max_num_kv_tokens = (
-                kv_cache_config.num_blocks // num_groups
+                kv_cache_config.num_blocks // len(capacity_groups)
            ) * min_block_size
dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size
pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size
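
A back-of-the-envelope sketch of the under-reporting the new docstring describes (not taken from the PR: every number below is made up, all block sizes are assumed to be a uniform 16 tokens, and the variable names are local to this sketch). For the Nemotron-H-style layout used in the tests above (one attention group plus three Mamba groups), dividing the block pool by all four groups quarters the per-token capacity the scheduler reports; dividing only by the groups that actually scale with sequence length does not:

# Hypothetical numbers illustrating the divisor change in scheduler.py above.
num_blocks = 8192        # total KV cache blocks (made up)
block_size = 16          # tokens per block, assumed uniform across groups
total_groups = 4         # 1 attention group + 3 Mamba groups
capacity_groups = 1      # only the attention group scales per token ('none'/'align')

old_max_num_kv_tokens = (num_blocks // total_groups) * block_size      # 32768
new_max_num_kv_tokens = (num_blocks // capacity_groups) * block_size   # 131072

With mamba_cache_mode='all' the filter keeps the Mamba groups, so both formulas use the same divisor of 4 and agree, which is the behavior the 'all' test above asserts.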