diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index d8ecf28cbed1..1887c04301e1 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -1741,16 +1741,17 @@ def test_get_kv_cache_config_one_worker(): ], ) - # different hidden size that cannot be aligned by using different block size + # different hidden size that can't be evenly divided — handled via page + # padding (page_size_padded) instead of raising NotImplementedError kv_cache_specs_hybrid = { "layer_1": new_kv_cache_spec(head_size=64), "layer_2": new_sliding_window_spec(head_size=96), } - - with pytest.raises(NotImplementedError): - get_kv_cache_configs( - vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] - )[0] + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] + )[0] + # Should succeed with padded page sizes + assert kv_cache_config_hybrid.num_blocks > 0 # Test num_gpu_blocks_override vllm_config.cache_config.num_gpu_blocks_override = 16 diff --git a/tests/v1/core/test_per_group_blockpool.py b/tests/v1/core/test_per_group_blockpool.py new file mode 100644 index 000000000000..49d858f5cf1e --- /dev/null +++ b/tests/v1/core/test_per_group_blockpool.py @@ -0,0 +1,424 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for per-group BlockPool allocation on hybrid models. + +Covers two real architectures: + - Qwen3.5 (GatedDeltaNet + full attention, every 4th layer) + - Nemotron-3-Nano (Mamba + MLP-only + full attention, 3 types) + +Verifies that O(1) groups (Mamba/GDN in none/align mode) get a small fixed +pool while O(n) groups (attention) get the bulk of memory, yielding +dramatically higher token capacity. +""" + +import pytest +import torch + +import vllm.v1.core.kv_cache_utils as kv_cache_utils +from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig +from vllm.utils.math_utils import cdiv +from vllm.utils.mem_constants import GiB_bytes +from vllm.v1.core.kv_cache_utils import ( + get_kv_cache_configs, + get_max_concurrency_for_kv_cache_config, +) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheTensor, + MambaSpec, +) + +pytestmark = pytest.mark.cpu_test + +BLOCK_SIZE = 16 + + +# --------------------------------------------------------------------------- +# Spec builders for real architectures +# --------------------------------------------------------------------------- + +def _qwen35_specs( + kv_dtype=torch.bfloat16, + mamba_dtype=torch.bfloat16, + mamba_cache_mode="none", + num_layers=24, + attn_interval=4, +): + """Qwen3.5-0.8B/27B: GDN + full attention layers. 
+ + 0.8B: 24 layers (18 GDN + 6 attn), kv_heads=2, head_dim=256 + 27B: 64 layers (48 GDN + 16 attn), kv_heads=4, head_dim=256 + """ + num_kv_heads = 2 if num_layers <= 24 else 4 + attn_spec = FullAttentionSpec( + block_size=BLOCK_SIZE, + num_kv_heads=num_kv_heads, + head_size=256, + dtype=kv_dtype, + ) + mamba_spec = MambaSpec( + block_size=BLOCK_SIZE, + shapes=((3, 8192), (32, 128, 128)), + dtypes=(mamba_dtype, mamba_dtype), + mamba_cache_mode=mamba_cache_mode, + ) + specs = {} + for i in range(num_layers): + if (i + 1) % attn_interval == 0: + specs[f"layer_{i}"] = attn_spec + else: + specs[f"layer_{i}"] = mamba_spec + + n_attn = sum(1 for s in specs.values() if isinstance(s, FullAttentionSpec)) + n_mamba = sum(1 for s in specs.values() if isinstance(s, MambaSpec)) + return specs, attn_spec, mamba_spec, n_attn, n_mamba + + +def _nemotron_specs( + kv_dtype=torch.bfloat16, + mamba_dtype=torch.bfloat16, + mamba_cache_mode="none", +): + """Nemotron-3-Nano-4B: Mamba + MLP-only + full attention. + + 42 layers: M=Mamba(21), -=MLP-only(17), *=full attention(4) + Pattern: M-M-M-MM-M-M*-M-M*-M-M-M*-M-M-MM*-MMM-M-M- + Attention: kv_heads=8, head_dim=128 + Mamba: 96 heads, ssm_state=128 + MLP-only layers have NO KV cache spec. + """ + pattern = "M-M-M-MM-M-M*-M-M*-M-M-M*-M-M-MM*-MMM-M-M-" + attn_spec = FullAttentionSpec( + block_size=BLOCK_SIZE, + num_kv_heads=8, + head_size=128, + dtype=kv_dtype, + ) + mamba_spec = MambaSpec( + block_size=BLOCK_SIZE, + shapes=((3, 8192), (96, 128, 128)), + dtypes=(mamba_dtype, mamba_dtype), + mamba_cache_mode=mamba_cache_mode, + ) + specs = {} + for i, c in enumerate(pattern): + if c == '*': + specs[f"layer_{i}"] = attn_spec + elif c == 'M': + specs[f"layer_{i}"] = mamba_spec + # '-' = MLP-only, no KV cache needed + + n_attn = sum(1 for s in specs.values() if isinstance(s, FullAttentionSpec)) + n_mamba = sum(1 for s in specs.values() if isinstance(s, MambaSpec)) + return specs, attn_spec, mamba_spec, n_attn, n_mamba + + +# --------------------------------------------------------------------------- +# Helper +# --------------------------------------------------------------------------- + +def _get_config(max_model_len=1024, mamba_cache_mode=None, max_num_seqs=256): + model_config = ModelConfig(max_model_len=max_model_len) + scheduler_config = SchedulerConfig( + max_num_batched_tokens=max_model_len, + max_num_seqs=max_num_seqs, + enable_chunked_prefill=True, + max_model_len=max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ) + kwargs = dict(model_config=model_config, scheduler_config=scheduler_config) + if mamba_cache_mode is not None: + kwargs["cache_config"] = CacheConfig(mamba_cache_mode=mamba_cache_mode) + return VllmConfig(**kwargs) + + +def _total_page_bytes(specs): + return sum(s.page_size_bytes for s in specs.values()) + + +# --------------------------------------------------------------------------- +# Qwen3.5 tests +# --------------------------------------------------------------------------- + +class TestQwen35PerGroupBlockPool: + """Per-group BlockPool tests for Qwen3.5 architecture.""" + + def test_split_allocation_active(self): + """Compact allocation should produce per_group_num_blocks.""" + vllm_config = _get_config() + specs, attn_spec, mamba_spec, n_attn, n_mamba = _qwen35_specs() + mem = _total_page_bytes(specs) * 100 + + config = get_kv_cache_configs(vllm_config, [specs], [mem])[0] + + assert config.per_group_num_blocks is not None, ( + "per_group_num_blocks should be set for hybrid models" + ) + assert 
len(config.per_group_num_blocks) == len(config.kv_cache_groups) + + def test_attention_gets_bulk_of_memory(self): + """Attention groups should get vastly more blocks than Mamba.""" + vllm_config = _get_config() + specs, attn_spec, mamba_spec, n_attn, n_mamba = _qwen35_specs() + mem = 10 * GiB_bytes + + config = get_kv_cache_configs(vllm_config, [specs], [mem])[0] + + if config.per_group_num_blocks is not None: + for i, group in enumerate(config.kv_cache_groups): + if isinstance(group.kv_cache_spec, FullAttentionSpec): + attn_blocks = config.per_group_num_blocks[i] + elif isinstance(group.kv_cache_spec, MambaSpec): + mamba_blocks = config.per_group_num_blocks[i] + + assert attn_blocks > mamba_blocks * 10, ( + f"Attention ({attn_blocks}) should have >> Mamba ({mamba_blocks}) blocks" + ) + + def test_capacity_improvement_over_naive(self): + """Per-group allocation should yield >3x more attention token capacity.""" + vllm_config = _get_config() + specs, attn_spec, mamba_spec, n_attn, n_mamba = _qwen35_specs() + mem = 10 * GiB_bytes + + # Naive: all layers share same blocks + total_page = _total_page_bytes(specs) + naive_blocks = int(mem // total_page) + naive_tokens = naive_blocks * BLOCK_SIZE + + # Per-group + config = get_kv_cache_configs(vllm_config, [specs], [mem])[0] + new_tokens = config.num_blocks * BLOCK_SIZE + + ratio = new_tokens / naive_tokens if naive_tokens > 0 else float('inf') + assert ratio >= 3.0, ( + f"Expected >=3x improvement, got {ratio:.1f}x " + f"(naive={naive_tokens}, new={new_tokens})" + ) + + def test_per_layer_tensors(self): + """Each layer should get its own tensor at its natural page size.""" + vllm_config = _get_config() + specs, attn_spec, mamba_spec, n_attn, n_mamba = _qwen35_specs() + mem = _total_page_bytes(specs) * 50 + + config = get_kv_cache_configs(vllm_config, [specs], [mem])[0] + + total_layers = n_attn + n_mamba + assert len(config.kv_cache_tensors) == total_layers, ( + f"Expected {total_layers} tensors, got {len(config.kv_cache_tensors)}" + ) + for t in config.kv_cache_tensors: + assert len(t.shared_by) == 1 + + def test_mamba_mode_all_no_split(self): + """When mamba_cache_mode='all', no split -- Mamba is O(n) too.""" + vllm_config = _get_config(mamba_cache_mode="all") + specs, *_ = _qwen35_specs(mamba_cache_mode="all") + mem = _total_page_bytes(specs) * 50 + + config = get_kv_cache_configs(vllm_config, [specs], [mem])[0] + + assert config.per_group_num_blocks is None, ( + "per_group_num_blocks should be None when mamba_cache_mode='all'" + ) + + def test_enough_blocks_for_max_model_len(self): + """Attention pool must have enough blocks for at least one full request.""" + max_model_len = 4096 + vllm_config = _get_config(max_model_len=max_model_len) + specs, attn_spec, *_ = _qwen35_specs() + mem = _total_page_bytes(specs) * 500 + + config = get_kv_cache_configs(vllm_config, [specs], [mem])[0] + + blocks_needed = cdiv(max_model_len, BLOCK_SIZE) + assert config.num_blocks >= blocks_needed, ( + f"Need {blocks_needed} blocks for max_model_len={max_model_len}, " + f"got {config.num_blocks}" + ) + + @pytest.mark.parametrize("num_layers,attn_interval", [ + (24, 4), # 0.8B: 18 GDN + 6 attn + (64, 4), # 27B: 48 GDN + 16 attn + ]) + def test_scales_with_model_size(self, num_layers, attn_interval): + """Both 0.8B and 27B architectures should benefit from split.""" + vllm_config = _get_config() + specs, attn_spec, mamba_spec, n_attn, n_mamba = _qwen35_specs( + num_layers=num_layers, attn_interval=attn_interval + ) + mem = _total_page_bytes(specs) * 100 + + config = 
get_kv_cache_configs(vllm_config, [specs], [mem])[0] + assert config.per_group_num_blocks is not None + assert config.num_blocks > 0 + + def test_pure_attention_unaffected(self): + """Pure attention model should not trigger per-group split.""" + vllm_config = _get_config() + attn_spec = FullAttentionSpec( + block_size=BLOCK_SIZE, num_kv_heads=8, head_size=128, + dtype=torch.bfloat16, + ) + specs = {f"layer_{i}": attn_spec for i in range(32)} + mem = attn_spec.page_size_bytes * 32 * 100 + + config = get_kv_cache_configs(vllm_config, [specs], [mem])[0] + + assert config.per_group_num_blocks is None, ( + "Pure attention should not use per-group split" + ) + + +# --------------------------------------------------------------------------- +# Nemotron-3-Nano tests +# --------------------------------------------------------------------------- + +class TestNemotronPerGroupBlockPool: + """Per-group BlockPool tests for Nemotron-3-Nano architecture. + + 42 layers with 3 types: Mamba(21) + MLP-only(17) + attention(4). + Only 4 attention layers -- per-group split is critical. + """ + + def test_split_allocation_active(self): + """Should activate per-group split despite 3 layer types.""" + vllm_config = _get_config() + specs, *_ = _nemotron_specs() + mem = _total_page_bytes(specs) * 100 + + config = get_kv_cache_configs(vllm_config, [specs], [mem])[0] + + assert config.per_group_num_blocks is not None + + def test_only_kv_layers_have_specs(self): + """MLP-only layers should not appear in KV cache specs.""" + specs, attn_spec, mamba_spec, n_attn, n_mamba = _nemotron_specs() + + assert n_attn == 4, f"Expected 4 attention layers, got {n_attn}" + assert n_mamba == 21, f"Expected 21 Mamba layers, got {n_mamba}" + assert len(specs) == 25, ( + f"Expected 25 KV layers (no MLP-only), got {len(specs)}" + ) + + def test_massive_capacity_improvement(self): + """Nemotron should see >10x capacity improvement from split allocation.""" + vllm_config = _get_config() + specs, attn_spec, mamba_spec, n_attn, n_mamba = _nemotron_specs() + mem = 10 * GiB_bytes + + # Naive + total_page = _total_page_bytes(specs) + naive_blocks = int(mem // total_page) + naive_tokens = naive_blocks * BLOCK_SIZE + + # Per-group + config = get_kv_cache_configs(vllm_config, [specs], [mem])[0] + new_tokens = config.num_blocks * BLOCK_SIZE + + ratio = new_tokens / naive_tokens if naive_tokens > 0 else float('inf') + assert ratio >= 10.0, ( + f"Expected >=10x improvement for Nemotron, got {ratio:.1f}x " + f"(naive={naive_tokens:,}, new={new_tokens:,})" + ) + + def test_attention_dominates_allocation(self): + """Attention blocks should vastly outnumber Mamba blocks.""" + vllm_config = _get_config() + specs, attn_spec, mamba_spec, n_attn, n_mamba = _nemotron_specs() + mem = 10 * GiB_bytes + + config = get_kv_cache_configs(vllm_config, [specs], [mem])[0] + + if config.per_group_num_blocks is not None: + for i, group in enumerate(config.kv_cache_groups): + if isinstance(group.kv_cache_spec, FullAttentionSpec): + attn_blocks = config.per_group_num_blocks[i] + elif isinstance(group.kv_cache_spec, MambaSpec): + mamba_blocks = config.per_group_num_blocks[i] + + assert attn_blocks > mamba_blocks * 5, ( + f"Attention ({attn_blocks}) should have >> " + f"Mamba ({mamba_blocks}) blocks" + ) + + def test_mamba_pool_sized_for_concurrency(self): + """Mamba pool should be sized for max concurrent requests.""" + max_num_seqs = 128 + vllm_config = _get_config(max_num_seqs=max_num_seqs) + specs, *_ = _nemotron_specs() + mem = 10 * GiB_bytes + + config = 
get_kv_cache_configs(vllm_config, [specs], [mem])[0] + + if config.per_group_num_blocks is not None: + for i, group in enumerate(config.kv_cache_groups): + if isinstance(group.kv_cache_spec, MambaSpec): + mamba_blocks = config.per_group_num_blocks[i] + assert mamba_blocks <= max_num_seqs * 3, ( + f"Mamba pool too large: {mamba_blocks} blocks " + f"(max_num_seqs={max_num_seqs})" + ) + assert mamba_blocks >= max_num_seqs, ( + f"Mamba pool too small: {mamba_blocks} blocks " + f"(need at least {max_num_seqs})" + ) + + +# --------------------------------------------------------------------------- +# Cross-architecture tests +# --------------------------------------------------------------------------- + +class TestPerGroupBlockPoolGeneral: + """Tests that apply to any hybrid architecture.""" + + @pytest.mark.parametrize("make_specs", [_qwen35_specs, _nemotron_specs], + ids=["qwen35", "nemotron"]) + def test_allocation_efficient(self, make_specs): + """Total allocation should use >85% of available memory.""" + vllm_config = _get_config() + specs, *_ = make_specs() + mem = 10 * GiB_bytes + + config = get_kv_cache_configs(vllm_config, [specs], [mem])[0] + + total_allocated = sum(t.size for t in config.kv_cache_tensors) + efficiency = total_allocated / mem + assert efficiency > 0.85, ( + f"Allocation efficiency {efficiency:.1%} < 85%" + ) + + @pytest.mark.parametrize("make_specs", [_qwen35_specs, _nemotron_specs], + ids=["qwen35", "nemotron"]) + def test_backward_compatible_when_no_split(self, make_specs): + """When mamba_cache_mode='all', behaves like shared pool.""" + vllm_config = _get_config(mamba_cache_mode="all") + specs, *_ = make_specs(mamba_cache_mode="all") + mem = _total_page_bytes(specs) * 50 + + config = get_kv_cache_configs(vllm_config, [specs], [mem])[0] + + assert config.per_group_num_blocks is None + + @pytest.mark.parametrize("make_specs", [_qwen35_specs, _nemotron_specs], + ids=["qwen35", "nemotron"]) + def test_num_blocks_consistent_with_tensors(self, make_specs): + """num_blocks should match the actual tensor sizes.""" + vllm_config = _get_config() + specs, attn_spec, *_ = make_specs() + mem = _total_page_bytes(specs) * 100 + + config = get_kv_cache_configs(vllm_config, [specs], [mem])[0] + + for t in config.kv_cache_tensors: + layer_name = t.shared_by[0] + spec = specs[layer_name] + if isinstance(spec, FullAttentionSpec): + expected = spec.page_size_bytes * config.num_blocks + assert t.size == expected, ( + f"{layer_name}: tensor size {t.size} != " + f"page_size({spec.page_size_bytes}) * " + f"num_blocks({config.num_blocks}) = {expected}" + ) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index eaa95dfe49f7..f931c74ebc02 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -11,6 +11,7 @@ BlockHashList, BlockHashListWithBlockSize, KVCacheBlock, + make_block_hash_with_group_id, ) from vllm.v1.core.single_type_kv_cache_manager import ( CrossAttentionManager, @@ -25,6 +26,31 @@ from vllm.v1.request import Request +class _MultiGroupPoolView: + """Wraps per-group BlockPools to present a unified get_cached_block + interface for prefix caching lookups across groups with separate pools.""" + + def __init__(self, group_pools: dict[int, BlockPool]): + self.group_pools = group_pools + + def get_cached_block( + self, block_hash: BlockHash, kv_cache_group_ids: list[int] + ) -> list[KVCacheBlock] | None: + cached_blocks = [] + for gid in kv_cache_group_ids: + pool = self.group_pools[gid] + 
block_hash_with_gid = make_block_hash_with_group_id( + block_hash, gid + ) + block = pool.cached_block_hash_to_block.get_one_block( + block_hash_with_gid + ) + if not block: + return None + cached_blocks.append(block) + return cached_blocks + + class KVCacheCoordinator(ABC): """ Coordinate the KV cache of different KV cache groups. @@ -46,20 +72,44 @@ def __init__( self.max_model_len = max_model_len self.enable_caching = enable_caching - self.block_pool = BlockPool( - kv_cache_config.num_blocks, - enable_caching, - hash_block_size, - enable_kv_cache_events, - metrics_collector, - ) + per_group_nb = kv_cache_config.per_group_num_blocks + num_groups = len(kv_cache_config.kv_cache_groups) + + if per_group_nb is not None: + assert len(per_group_nb) == num_groups + # Create per-group pools. Each group gets its own BlockPool + # so block IDs are independent per group. + self.group_pools: list[BlockPool] = [ + BlockPool( + nb, + enable_caching, + hash_block_size, + enable_kv_cache_events, + metrics_collector, + ) + for nb in per_group_nb + ] + # Primary pool: the largest (attention) pool, used for usage + # reporting, events, and backward compatibility. + self.block_pool = max( + self.group_pools, key=lambda p: p.num_gpu_blocks + ) + else: + self.block_pool = BlockPool( + kv_cache_config.num_blocks, + enable_caching, + hash_block_size, + enable_kv_cache_events, + metrics_collector, + ) + self.group_pools = [self.block_pool] * num_groups # Needs special handling for find_longest_cache_hit if eagle is enabled self.use_eagle = use_eagle self.single_type_managers = tuple( get_manager_for_kv_cache_spec( kv_cache_spec=kv_cache_group.kv_cache_spec, - block_pool=self.block_pool, + block_pool=self.group_pools[i], enable_caching=enable_caching, kv_cache_group_id=i, dcp_world_size=dcp_world_size, @@ -114,6 +164,38 @@ def get_num_blocks_to_allocate( ) return num_blocks_to_allocate + def has_enough_blocks( + self, + request_id: str, + num_tokens: int, + new_computed_blocks: tuple[Sequence[KVCacheBlock], ...], + num_encoder_tokens: int, + total_computed_tokens: int, + num_tokens_main_model: int, + ) -> bool: + """ + Check if all per-group pools have enough free blocks for this request. + + With per-group BlockPools, each manager's allocation must be checked + against its own pool. Returns True if all pools can accommodate. + """ + for i, manager in enumerate(self.single_type_managers): + if isinstance(manager, CrossAttentionManager): + needed = manager.get_num_blocks_to_allocate( + request_id, num_encoder_tokens, [], 0, num_encoder_tokens + ) + else: + needed = manager.get_num_blocks_to_allocate( + request_id, + num_tokens, + new_computed_blocks[i], + total_computed_tokens, + num_tokens_main_model, + ) + if needed > self.group_pools[i].get_num_free_blocks(): + return False + return True + def allocate_new_computed_blocks( self, request_id: str, @@ -511,11 +593,19 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: num_blocks = curr_hit_length // spec.block_size curr_hit_length = num_blocks * spec.block_size else: + # When groups have separate pools, create a view that + # dispatches get_cached_block to per-group pools. 
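+                # A hit only counts if the hash is cached in every group's
+                # pool; the view returns None as soon as any pool misses.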
+ if self.kv_cache_config.per_group_num_blocks is not None: + pool_view = _MultiGroupPoolView( + {gid: self.group_pools[gid] for gid in group_ids} + ) + else: + pool_view = self.block_pool hit_blocks = manager_cls.find_longest_cache_hit( block_hashes=_get_block_hashes(spec), max_length=curr_hit_length, kv_cache_group_ids=group_ids, - block_pool=self.block_pool, + block_pool=pool_view, kv_cache_spec=spec, use_eagle=self.use_eagle, alignment_tokens=self.lcm_block_size, diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index dcec5e05bf97..2a6b4dc2ffab 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -142,6 +142,13 @@ def __init__( self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool self.kv_cache_config = kv_cache_config + # Collect unique pools for multi-pool operations (reset, events) + seen_ids: set[int] = set() + self._unique_pools: list = [] + for pool in self.coordinator.group_pools: + if id(pool) not in seen_ids: + seen_ids.add(id(pool)) + self._unique_pools.append(pool) # Pre-constructed KVCacheBlocks with no blocks, callers should use this # via create_kv_cache_blocks instead of creating new ones to avoid GC @@ -158,8 +165,10 @@ def usage(self) -> float: Returns: The KV cache usage (between 0.0 and 1.0). + With per-group pools, returns the max usage across all pools + since any full pool blocks new requests. """ - return self.block_pool.get_usage() + return max(pool.get_usage() for pool in self._unique_pools) def make_prefix_cache_stats(self) -> PrefixCacheStats | None: """Get (and reset) the prefix cache stats. @@ -374,7 +383,7 @@ def allocate_slots( request.request_id, total_computed_tokens ) - num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( + if not self.coordinator.has_enough_blocks( request_id=request.request_id, num_tokens=num_tokens_need_slot, new_computed_blocks=new_computed_block_list, @@ -382,10 +391,8 @@ def allocate_slots( total_computed_tokens=num_local_computed_tokens + num_external_computed_tokens, num_tokens_main_model=num_tokens_main_model, - ) - - if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): - # Cannot allocate new blocks + ): + # Cannot allocate new blocks in one or more pools return None if ( @@ -455,7 +462,23 @@ def evict_blocks(self, block_ids: set[int]) -> None: Args: block_ids: Set of block IDs to evict from cache. """ - self.block_pool.evict_blocks(block_ids) + if len(self._unique_pools) == 1: + self._unique_pools[0].evict_blocks(block_ids) + else: + # With per-group pools, block IDs are pool-local (both start + # from 0). We must evict only from the pool that owns each + # block. Build a reverse map from block_id -> pool by + # checking which pool's cached_block_hash_to_block contains + # the block. + for pool in self._unique_pools: + pool_ids = set() + for bid in block_ids: + if bid < pool.num_gpu_blocks: + blk = pool.blocks[bid] + if blk.block_hash is not None: + pool_ids.add(bid) + if pool_ids: + pool.evict_blocks(pool_ids) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -466,7 +489,7 @@ def reset_prefix_cache(self) -> bool: bool: True if the prefix cache is successfully reset, False otherwise. 
""" - if not self.block_pool.reset_prefix_cache(): + if not all(pool.reset_prefix_cache() for pool in self._unique_pools): return False if self.log_stats: assert self.prefix_cache_stats is not None @@ -513,7 +536,10 @@ def take_events(self) -> list[KVCacheEvent]: Returns: A list of KV cache events. """ - return self.block_pool.take_events() + events: list[KVCacheEvent] = [] + for pool in self._unique_pools: + events.extend(pool.take_events()) + return events def get_blocks(self, request_id: str) -> KVCacheBlocks: """Get the blocks of a request.""" diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 9ab5af0f6fb0..afc1ac97090a 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -24,6 +24,7 @@ KVCacheGroupSpec, KVCacheSpec, KVCacheTensor, + MambaSpec, SlidingWindowSpec, UniformTypeKVCacheSpecs, ) @@ -939,9 +940,32 @@ def unify_kv_cache_spec_page_size( else: layer_page_size = layer_spec.page_size_bytes if max_page_size % layer_page_size != 0: + # Page sizes not directly divisible. Increase block_size + # so this layer's page fills the max page as closely as + # possible, then pad the remainder. This minimizes waste + # for compressed caches in hybrid (attention + GDN) models. + per_token = layer_page_size // layer_spec.block_size + if per_token > 0: + new_block_size = (max_page_size // per_token // 16) * 16 + if new_block_size < 16: + new_block_size = 16 + # Safety check: ensure tokens fit within padded page + if new_block_size * per_token > max_page_size: + raise NotImplementedError( + f"Cannot unify page sizes: {new_block_size} " + f"tokens x {per_token} bytes/token = " + f"{new_block_size * per_token} exceeds " + f"max_page_size {max_page_size}." + ) + new_spec = replace( + layer_spec, + block_size=new_block_size, + page_size_padded=max_page_size, + ) + new_kv_cache_spec[layer_name] = new_spec + continue raise NotImplementedError( - "The page size of the layer is not divisible by the " - "maximum page size. Cannot unify by adjusting block_size." + "Cannot unify page sizes: per-token bytes is zero." ) ratio = max_page_size // layer_page_size new_block_size = layer_spec.block_size * ratio @@ -1104,6 +1128,7 @@ def get_kv_cache_config_from_groups( ) # Determine how model runners should initialize the KV cache tensors. + per_group_num_blocks = None if len(kv_cache_groups) == 1 and isinstance( kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs ): @@ -1123,37 +1148,109 @@ def get_kv_cache_config_from_groups( for layer_name in kv_cache_groups[0].layer_names ] else: - # General case: - # We will have group_size memory pools, each is shared by one layer from - # each group. As layers of different groups have different block table, - # they will use different parts of the shared Tensor. - # The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2), - # (sw.1, padding) will be: (group_size = 2) - # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2 - # full.1, sw.2: share another Tensor with size=available_memory//2 + # General case: multiple groups with potentially different page sizes. + # Each layer gets its own tensor. O(1) groups (mamba in none/align + # mode) get a fixed-size allocation based on max_num_seqs, freeing + # the remaining memory for O(n) groups (attention) to maximize + # KV cache token capacity. 
group_size = max(len(group.layer_names) for group in kv_cache_groups) - - page_size = get_uniform_page_size( - [group.kv_cache_spec for group in kv_cache_groups] - ) assert group_size > 0, "group_size must be greater than 0" - num_blocks = get_num_blocks( - vllm_config, group_size, available_memory, page_size - ) - kv_cache_tensors = [] - for i in range(group_size): - shared_by = [] - for j in range(len(kv_cache_groups)): - if i < len(kv_cache_groups[j].layer_names): - shared_by.append(kv_cache_groups[j].layer_names[i]) - kv_cache_tensors.append( - KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by) + + # Separate O(1) groups (mamba) from O(n) groups (attention) + o1_groups = [] # Fixed allocation: max_seqs blocks + on_groups = [] # Dynamic allocation: as many blocks as memory allows + for group in kv_cache_groups: + spec = group.kv_cache_spec + if (isinstance(spec, MambaSpec) + and spec.mamba_cache_mode != "all"): + o1_groups.append(group) + else: + on_groups.append(group) + + if o1_groups and on_groups: + # Split allocation: fixed mamba pool + dynamic attention pool + max_seqs = vllm_config.scheduler_config.max_num_seqs + # Mamba needs 1 block per sequence (2 for "align" with speculation) + mamba_blocks_per_seq = max( + 1 + g.kv_cache_spec.num_speculative_blocks + for g in o1_groups + if isinstance(g.kv_cache_spec, MambaSpec) + ) + # +1 for the null_block that BlockPool always reserves + mamba_blocks = max_seqs * mamba_blocks_per_seq + 1 + mamba_per_slot = sum( + g.kv_cache_spec.page_size_bytes for g in o1_groups + ) + mamba_memory = mamba_per_slot * mamba_blocks + + # Remaining memory for attention + attn_memory = available_memory - mamba_memory + if attn_memory < 0: + attn_memory = 0 + attn_per_slot = sum( + g.kv_cache_spec.page_size_bytes for g in on_groups + ) + attn_group_size = max( + len(g.layer_names) for g in on_groups + ) if on_groups else 1 + num_blocks = int( + attn_memory // attn_per_slot // attn_group_size + ) if attn_per_slot > 0 else 0 + num_blocks = max(num_blocks, 0) + num_blocks = may_override_num_blocks(vllm_config, num_blocks) + + # Build per_group_num_blocks: map each group to its block count + o1_set = set(id(g) for g in o1_groups) + per_group_num_blocks = [ + mamba_blocks if id(g) in o1_set else num_blocks + for g in kv_cache_groups + ] + + # Build tensors: attention layers get num_blocks, mamba gets + # mamba_blocks + kv_cache_tensors = [] + for i in range(group_size): + for group in on_groups: + if i < len(group.layer_names): + kv_cache_tensors.append(KVCacheTensor( + size=group.kv_cache_spec.page_size_bytes + * num_blocks, + shared_by=[group.layer_names[i]], + )) + for group in o1_groups: + if i < len(group.layer_names): + kv_cache_tensors.append(KVCacheTensor( + size=group.kv_cache_spec.page_size_bytes + * mamba_blocks, + shared_by=[group.layer_names[i]], + )) + else: + # No split needed — all groups are O(n). Use the original + # shared-tensor approach: group_size pools, each shared by + # one layer from each group. 
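+            # The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2),
+            # (sw.1, padding) with group_size = 2 is then:
+            #   full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
+            #   full.1, sw.2: share another Tensor with size=available_memory//2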
+ page_size = get_uniform_page_size( + [group.kv_cache_spec for group in kv_cache_groups] + ) + num_blocks = get_num_blocks( + vllm_config, group_size, available_memory, page_size ) + kv_cache_tensors = [] + for i in range(group_size): + shared_by = [] + for j in range(len(kv_cache_groups)): + if i < len(kv_cache_groups[j].layer_names): + shared_by.append(kv_cache_groups[j].layer_names[i]) + kv_cache_tensors.append( + KVCacheTensor( + size=page_size * num_blocks, shared_by=shared_by + ) + ) return KVCacheConfig( num_blocks=num_blocks, kv_cache_tensors=kv_cache_tensors, kv_cache_groups=kv_cache_groups, + per_group_num_blocks=per_group_num_blocks, ) @@ -1251,14 +1348,28 @@ def get_kv_cache_groups( # same window size). Put all layers into one group. return _get_kv_cache_groups_uniform_type(uniform_spec) - # As KVCacheManager can only allocate memory of one size, we need to unify - # the page size of the layers. For cases cannot be unified, this function - # will raise an error. + # Check if the model has O(1) mamba groups + O(n) attention groups. + # If so, skip page unification — per-group BlockPool gives each group + # its own tensors, so they don't need matching page sizes. This avoids + # inflating attention pages (e.g., 32KB) to match large mamba pages + # (e.g., 1MB), which would waste 97% of attention memory. + has_o1_mamba = any( + isinstance(spec, MambaSpec) and spec.mamba_cache_mode != "all" + for spec in kv_cache_spec.values() + ) + has_on_attn = any( + not isinstance(spec, MambaSpec) + for spec in kv_cache_spec.values() + ) + if has_o1_mamba and has_on_attn: + # Skip page unification — each group keeps its natural page size. + # The split allocator in get_kv_cache_config_from_groups handles + # the different page sizes via per-group BlockPools. + return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) + + # For non-mamba hybrids (e.g., full attention + sliding window), + # unify page sizes for the shared tensor layout. kv_cache_spec = unify_kv_cache_spec_page_size(kv_cache_spec) - # Model contains multiple attention types, but KV cache of all layers - # have the same physical memory per block per layer. Split the layers - # into groups with the same number of layers, and thus same total page - # size. return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) @@ -1299,11 +1410,28 @@ def _report_kv_cache_config( ) # Log the KV cache size and maximum concurrency. - num_tokens = ( - kv_cache_config.num_blocks - // len(kv_cache_config.kv_cache_groups) - * min_block_size - ) + # With per-group pools, count tokens from the attention (O(n)) groups only, + # since mamba groups have fixed O(1) allocation per sequence. + if kv_cache_config.per_group_num_blocks is not None: + # Use attention blocks (the primary num_blocks) and count only + # attention groups for the division. 
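+        # (A group is treated as attention here iff its block count equals
+        # num_blocks, which tracks the primary attention pool.)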
+ attn_group_count = sum( + 1 for g, nb in zip( + kv_cache_config.kv_cache_groups, + kv_cache_config.per_group_num_blocks, + ) + if nb == kv_cache_config.num_blocks + ) + attn_group_count = max(attn_group_count, 1) + num_tokens = ( + kv_cache_config.num_blocks // attn_group_count * min_block_size + ) + else: + num_tokens = ( + kv_cache_config.num_blocks + // len(kv_cache_config.kv_cache_groups) + * min_block_size + ) dcp_size = vllm_config.parallel_config.decode_context_parallel_size pcp_size = vllm_config.parallel_config.prefill_context_parallel_size if pcp_size * dcp_size > 1: @@ -1353,18 +1481,25 @@ def _max_memory_usage_bytes_from_groups( for spec in per_layer_specs.values() ) - # General case: group_size pools, each shared by one layer per group - # Memory = group_size * page_size * blocks_for_max_len + # General case: sum max memory across all groups. + # For uniform-page hybrids: group_size * page_size * blocks_needed. + # For split hybrids (O(1) mamba + O(n) attention): sum per-group. group_size = max(len(group.layer_names) for group in kv_cache_groups) - page_size = get_uniform_page_size( - [group.kv_cache_spec for group in kv_cache_groups] - ) - blocks_needed = sum( - cdiv(group.kv_cache_spec.max_memory_usage_bytes(vllm_config), page_size) - for group in kv_cache_groups - ) - - return group_size * page_size * blocks_needed + page_sizes = set(g.kv_cache_spec.page_size_bytes for g in kv_cache_groups) + if len(page_sizes) == 1: + page_size = page_sizes.pop() + any_spec = kv_cache_groups[0].kv_cache_spec + blocks_needed = cdiv( + any_spec.max_memory_usage_bytes(vllm_config), page_size + ) + return group_size * page_size * blocks_needed + else: + # Non-uniform pages: sum each group's max usage independently + return sum( + len(g.layer_names) + * g.kv_cache_spec.max_memory_usage_bytes(vllm_config) + for g in kv_cache_groups + ) def _estimate_max_model_len_from_groups( @@ -1599,20 +1734,65 @@ def get_kv_cache_configs( # Change the num_blocks of each rank to the smallest among all ranks. # We also need to shrink the tensor size proportionally to avoid # allocating unused memory. - min_num_blocks = min( - kv_cache_config.num_blocks for kv_cache_config in kv_cache_configs + has_per_group = any( + c.per_group_num_blocks is not None for c in kv_cache_configs ) - for kv_cache_config in kv_cache_configs: - num_blocks_old = kv_cache_config.num_blocks - kv_cache_config.num_blocks = min_num_blocks - # Shrink tensor size proportionally - for tensor in kv_cache_config.kv_cache_tensors: - assert tensor.size % num_blocks_old == 0 - tensor.size = tensor.size // num_blocks_old * min_num_blocks + if has_per_group: + # Per-group block counts: shrink each group independently. + # Build layer_name → group_id mapping from the first config. 
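+        # (The group layout is assumed identical across workers; only the
+        # per-group block counts differ between workers.)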
+ ref_config = kv_cache_configs[0] + layer_to_group: dict[str, int] = {} + for gid, group in enumerate(ref_config.kv_cache_groups): + for layer_name in group.layer_names: + layer_to_group[layer_name] = gid + + num_groups = len(ref_config.kv_cache_groups) + # Find min per-group block count across workers + min_per_group = [ + min( + c.per_group_num_blocks[g] # type: ignore[index] + for c in kv_cache_configs + ) + for g in range(num_groups) + ] + + for kv_cache_config in kv_cache_configs: + old_per_group = list(kv_cache_config.per_group_num_blocks) # type: ignore[arg-type] + kv_cache_config.per_group_num_blocks = min_per_group + # num_blocks tracks the primary (attention) block count + kv_cache_config.num_blocks = max(min_per_group) if min_per_group else 0 + + # Shrink each tensor based on its group's block count + for tensor in kv_cache_config.kv_cache_tensors: + gid = layer_to_group[tensor.shared_by[0]] + old_nb = old_per_group[gid] + new_nb = min_per_group[gid] + if old_nb > 0: + assert tensor.size % old_nb == 0, ( + f"Tensor size {tensor.size} not divisible by " + f"old num_blocks {old_nb} for group {gid}" + ) + tensor.size = tensor.size // old_nb * new_nb + + if len(kv_cache_config.kv_cache_groups) > 0: + _report_kv_cache_config(vllm_config, kv_cache_config) + else: + # Original single-pool logic + min_num_blocks = min( + kv_cache_config.num_blocks for kv_cache_config in kv_cache_configs + ) + for kv_cache_config in kv_cache_configs: + num_blocks_old = kv_cache_config.num_blocks + kv_cache_config.num_blocks = min_num_blocks + + # Shrink tensor size proportionally + for tensor in kv_cache_config.kv_cache_tensors: + assert tensor.size % num_blocks_old == 0 + tensor.size = tensor.size // num_blocks_old * min_num_blocks - if len(kv_cache_config.kv_cache_groups) > 0: - _report_kv_cache_config(vllm_config, kv_cache_config) + if len(kv_cache_config.kv_cache_groups) > 0: + _report_kv_cache_config(vllm_config, kv_cache_config) return kv_cache_configs diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index fe524ccace16..9498daf05a28 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -280,8 +280,16 @@ def __init__( ] ) num_groups = len(kv_cache_config.kv_cache_groups) + if kv_cache_config.per_group_num_blocks is not None: + # With per-group pools, count only attention (O(n)) groups + attn_group_count = max(1, sum( + 1 for nb in kv_cache_config.per_group_num_blocks + if nb == kv_cache_config.num_blocks + )) + else: + attn_group_count = num_groups self.max_num_kv_tokens = ( - kv_cache_config.num_blocks // num_groups + kv_cache_config.num_blocks // attn_group_count ) * min_block_size dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 62bdb8113a32..e302c85fa77b 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -1106,15 +1106,26 @@ def __init__( self.sink_blocks = self.block_pool.free_block_queue.popleft_n(num_sink_block) -spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { - FullAttentionSpec: FullAttentionManager, - MLAAttentionSpec: FullAttentionManager, - SlidingWindowSpec: SlidingWindowManager, - ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, - MambaSpec: MambaManager, - CrossAttentionSpec: CrossAttentionManager, - 
SinkFullAttentionSpec: SinkFullAttentionManager, -} +def _build_spec_manager_map(): + m: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { + FullAttentionSpec: FullAttentionManager, + MLAAttentionSpec: FullAttentionManager, + SlidingWindowSpec: SlidingWindowManager, + ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, + MambaSpec: MambaManager, + CrossAttentionSpec: CrossAttentionManager, + SinkFullAttentionSpec: SinkFullAttentionManager, + } + try: + from vllm.v1.attention.backends.turboquant_attn import ( + TurboQuantFullAttentionSpec, + ) + m[TurboQuantFullAttentionSpec] = FullAttentionManager + except ImportError: + pass + return m + +spec_manager_map = _build_spec_manager_map() def get_manager_for_kv_cache_spec( diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 6f8ad8e7d8ef..fed40565076f 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -538,7 +538,7 @@ class KVCacheConfig: """ num_blocks: int - """The number of KV cache blocks""" + """The number of KV cache blocks (for the primary/attention groups)""" kv_cache_tensors: list[KVCacheTensor] """How should model runner initialize the KV cache tensors for each layer""" kv_cache_groups: list[KVCacheGroupSpec] @@ -549,6 +549,12 @@ class KVCacheConfig: For models with multiple types of attention, there will be multiple groups, see `_get_kv_cache_config_uniform_page_size` for more details. """ + per_group_num_blocks: list[int] | None = None + """Per-group block counts. When set, each KV cache group gets its own + BlockPool with the specified number of blocks. This enables split + allocation where O(1) groups (e.g., mamba in none/align mode) get a + small fixed pool and O(n) groups (attention) get the rest of memory. + If None, all groups share a single pool with num_blocks blocks.""" @property def has_mamba_layers(self) -> bool: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7a21117fb64c..2c6c9bcf1792 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -6646,9 +6646,17 @@ def _reshape_kv_cache_tensors( kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) ] + raw = kv_cache_raw_tensors[layer_name].view(dtype) + shape_numel = 1 + for s in kv_cache_shape: + shape_numel *= s + if raw.numel() > shape_numel: + # Padded allocation (e.g. compressed KV cache aligned + # to recurrent layer page size). Use only the needed + # portion; the rest is padding. + raw = raw.narrow(0, 0, shape_numel) kv_caches[layer_name] = ( - kv_cache_raw_tensors[layer_name] - .view(dtype) + raw .view(kv_cache_shape) .permute(*inv_order) ) @@ -6878,8 +6886,15 @@ def init_routed_experts_capturer(self): ] ) num_groups = len(self.kv_cache_config.kv_cache_groups) + if self.kv_cache_config.per_group_num_blocks is not None: + attn_group_count = max(1, sum( + 1 for nb in self.kv_cache_config.per_group_num_blocks + if nb == self.kv_cache_config.num_blocks + )) + else: + attn_group_count = num_groups self.max_num_kv_tokens = ( - self.kv_cache_config.num_blocks // num_groups + self.kv_cache_config.num_blocks // attn_group_count ) * min_block_size dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size
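For intuition, a back-of-envelope sketch of the capacity win the new tests assert, using the Qwen3.5-0.8B numbers from test_per_group_blockpool.py (bf16, block_size=16, 6 attention layers with 2 KV heads of head_dim 256, 18 GDN layers with (3, 8192) + (32, 128, 128) states) and ignoring the small fixed Mamba reserve:

    # Simplified sketch; the real allocator also reserves
    # max_num_seqs * blocks_per_seq + 1 blocks for each O(1) Mamba group.
    GiB = 1024**3
    mem = 10 * GiB
    block_size = 16

    attn_page = 2 * 2 * 256 * block_size * 2      # K+V, 2 kv_heads, head_dim 256, bf16
    mamba_page = (3 * 8192 + 32 * 128 * 128) * 2  # conv + SSM state, bf16
    n_attn, n_mamba = 6, 18

    # Naive shared pool: every block pays for all 24 layers' pages.
    naive_blocks = mem // (n_attn * attn_page + n_mamba * mamba_page)
    # Per-group pools: an attention block pays only for the 6 attention pages.
    split_blocks = mem // (n_attn * attn_page)

    print(naive_blocks * block_size, "->", split_blocks * block_size, "tokens")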