diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index 985b97c69ca4..98aaf6135bba 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -36,6 +36,7 @@
     is_kv_cache_spec_uniform,
     make_block_hash_with_group_id,
     tensor_data,
+    token_capacity_kv_cache_groups,
 )
 from vllm.v1.kv_cache_interface import (
     ChunkedLocalAttentionSpec,
@@ -2214,3 +2215,82 @@ def test_hma_not_disabled_when_kv_events_enabled():
     assert vllm_config.scheduler_config.disable_hybrid_kv_cache_manager is False, (
         "kv_events_config must not force-disable the hybrid KV cache manager."
     )
+
+
+def _vllm_config_with_mamba_mode(mamba_cache_mode: str) -> VllmConfig:
+    from vllm.config import CacheConfig
+
+    return VllmConfig(
+        model_config=ModelConfig(max_model_len=16),
+        cache_config=CacheConfig(mamba_cache_mode=mamba_cache_mode),
+    )
+
+
+def _kv_cache_config(*groups: KVCacheGroupSpec) -> KVCacheConfig:
+    return KVCacheConfig(
+        num_blocks=0,
+        kv_cache_tensors=[],
+        kv_cache_groups=list(groups),
+    )
+
+
+def test_token_capacity_groups_dense_all_attention():
+    """All-attention models: every group contributes to per-token capacity."""
+    vllm_config = _vllm_config_with_mamba_mode("none")
+    attn = new_kv_cache_spec()
+    config = _kv_cache_config(
+        KVCacheGroupSpec(["a0"], attn),
+        KVCacheGroupSpec(["a1"], attn),
+    )
+    assert token_capacity_kv_cache_groups(vllm_config, config) == config.kv_cache_groups
+
+
+@pytest.mark.parametrize("mamba_mode", ["none", "align"])
+def test_token_capacity_groups_hybrid_excludes_o1_mamba(mamba_mode):
+    """Hybrid with mamba_cache_mode in ('none', 'align'): filter drops Mamba."""
+    vllm_config = _vllm_config_with_mamba_mode(mamba_mode)
+    attn_group = KVCacheGroupSpec(["a0"], new_kv_cache_spec())
+    mamba_group = KVCacheGroupSpec(["m0"], new_mamba_spec())
+    config = _kv_cache_config(attn_group, mamba_group)
+    assert token_capacity_kv_cache_groups(vllm_config, config) == [attn_group]
+
+
+def test_token_capacity_groups_hybrid_mamba_all_includes_mamba():
+    """mamba_cache_mode='all': Mamba state scales with sequence length, kept."""
+    vllm_config = _vllm_config_with_mamba_mode("all")
+    attn_group = KVCacheGroupSpec(["a0"], new_kv_cache_spec())
+    mamba_group = KVCacheGroupSpec(["m0"], new_mamba_spec())
+    config = _kv_cache_config(attn_group, mamba_group)
+    assert token_capacity_kv_cache_groups(vllm_config, config) == [
+        attn_group,
+        mamba_group,
+    ]
+
+
+def test_token_capacity_groups_mamba_only_falls_back():
+    """Mamba-only model with mode='none' would filter to empty; keep all."""
+    vllm_config = _vllm_config_with_mamba_mode("none")
+    mamba_groups = [
+        KVCacheGroupSpec(["m0"], new_mamba_spec()),
+        KVCacheGroupSpec(["m1"], new_mamba_spec()),
+    ]
+    config = _kv_cache_config(*mamba_groups)
+    assert token_capacity_kv_cache_groups(vllm_config, config) == mamba_groups
+
+
+def test_token_capacity_groups_multiple_attn_and_mamba():
+    """Nemotron-H-style 1 attn + 3 Mamba groups: drops to 1 attn group."""
+    vllm_config = _vllm_config_with_mamba_mode("none")
+    attn_group = KVCacheGroupSpec(["a0"], new_kv_cache_spec())
+    mamba_groups = [
+        KVCacheGroupSpec([f"m{i}"], new_mamba_spec()) for i in range(3)
+    ]
+    config = _kv_cache_config(attn_group, *mamba_groups)
+    assert token_capacity_kv_cache_groups(vllm_config, config) == [attn_group]
+
+
+def test_token_capacity_groups_empty_config_returns_empty():
+    """Edge case: no groups at all → empty list (no IndexError)."""
+    vllm_config = _vllm_config_with_mamba_mode("none")
+    config = _kv_cache_config()
+    assert token_capacity_kv_cache_groups(vllm_config, config) == []
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index b57e10b67faa..8ecb7b596ea6 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -19,6 +19,7 @@
 from vllm.utils.math_utils import cdiv, round_up
 from vllm.utils.mem_utils import format_gib
 from vllm.v1.kv_cache_interface import (
+    AttentionSpec,
     ChunkedLocalAttentionSpec,
     FullAttentionSpec,
     KVCacheConfig,
@@ -1683,6 +1684,31 @@ def generate_scheduler_kv_cache_config(
     return cfg
 
 
+def token_capacity_kv_cache_groups(
+    vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
+) -> list[KVCacheGroupSpec]:
+    """KV cache groups that contribute to per-token capacity.
+
+    Attention groups always scale with sequence length. Mamba groups only
+    scale when ``mamba_cache_mode == 'all'``; in ``'none'`` and ``'align'``
+    they hold O(1) state per request and pre-reserve a fixed number of
+    blocks, so counting them in the per-token divisor under-reports
+    capacity on hybrid models.
+
+    Falls back to all groups if the filter would produce an empty list.
+    """
+    mamba_scales = (
+        getattr(vllm_config.cache_config, "mamba_cache_mode", "none") == "all"
+    )
+    groups = [
+        g
+        for g in kv_cache_config.kv_cache_groups
+        if isinstance(g.kv_cache_spec, AttentionSpec)
+        or (isinstance(g.kv_cache_spec, MambaSpec) and mamba_scales)
+    ]
+    return groups or list(kv_cache_config.kv_cache_groups)
+
+
 def _report_kv_cache_config(
     vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
 ) -> None:
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 395fa80bfe53..4154ddb4f0c4 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -38,6 +38,7 @@
 )
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
 from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
+from vllm.v1.core.kv_cache_utils import token_capacity_kv_cache_groups
 from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
 from vllm.v1.core.sched.output import (
     CachedRequestData,
@@ -277,15 +278,12 @@ def __init__(
             if isinstance(group.kv_cache_spec, AttentionSpec):
                 self.routed_experts_attn_gid = gid
                 break
-        min_block_size = min(
-            [
-                group.kv_cache_spec.block_size
-                for group in kv_cache_config.kv_cache_groups
-            ]
+        capacity_groups = token_capacity_kv_cache_groups(
+            self.vllm_config, kv_cache_config
         )
-        num_groups = len(kv_cache_config.kv_cache_groups)
+        min_block_size = min(g.kv_cache_spec.block_size for g in capacity_groups)
         self.max_num_kv_tokens = (
-            kv_cache_config.num_blocks // num_groups
+            kv_cache_config.num_blocks // len(capacity_groups)
         ) * min_block_size
         dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size
         pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size
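A quick worked example of the scheduler arithmetic this diff changes, with made-up numbers (1000 free blocks, block_size 16, and the Nemotron-H-style 1 attention + 3 Mamba group split used in the tests); it only illustrates why dividing num_blocks by every group under-reports per-token capacity and is not code from this PR:

# Hypothetical figures for illustration only; not taken from the PR.
num_blocks = 1000          # free KV cache blocks in kv_cache_config
block_size = 16            # tokens per block of the attention group
num_all_groups = 4         # 1 attention group + 3 Mamba groups
num_capacity_groups = 1    # only the attention group scales per token

# Old formula: O(1) Mamba groups share the divisor and shrink the result.
old_max_num_kv_tokens = (num_blocks // num_all_groups) * block_size       # 4000
# New formula: divide only by the token-capacity groups.
new_max_num_kv_tokens = (num_blocks // num_capacity_groups) * block_size  # 16000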