From 956b14b4f1c0fc00d89e1e91fad2cc5438e7d1a5 Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Sat, 2 May 2026 10:16:52 +0900 Subject: [PATCH 1/6] Fix hybrid Mamba attention KV cache allocation Add explicit KV cache memory-model metadata, compact request-constant block pools, and pool-aware config/manager/worker handling for hybrid Mamba and attention models. Mamba cache mode 'all' keeps the legacy token-proportional path. Unsupported request-constant combinations fail closed for prefix caching, offload, connector, and full CUDA graph paths. Co-authored-by: OpenAI Codex Signed-off-by: lesj0610 --- tests/v1/core/test_block_pool.py | 172 ++++++ tests/v1/core/test_kv_cache_coordinator.py | 376 ++++++++++++ tests/v1/core/test_kv_cache_invariants.py | 188 ++++++ tests/v1/core/test_kv_cache_utils.py | 570 +++++++++++++++++- tests/v1/core/test_prefix_caching.py | 103 ++-- tests/v1/core/test_scheduler.py | 93 +++ .../core/test_single_type_kv_cache_manager.py | 211 ++++++- tests/v1/simple_kv_offload/test_scheduler.py | 70 +++ tests/v1/worker/test_gpu_model_runner.py | 310 ++++++++++ vllm/config/compilation.py | 20 + vllm/platforms/interface.py | 10 + vllm/platforms/xpu.py | 25 +- vllm/v1/core/block_pool.py | 111 +++- vllm/v1/core/kv_cache_coordinator.py | 133 +++- vllm/v1/core/kv_cache_manager.py | 61 +- vllm/v1/core/kv_cache_utils.py | 488 +++++++++++++-- vllm/v1/core/sched/scheduler.py | 33 +- vllm/v1/core/single_type_kv_cache_manager.py | 96 ++- vllm/v1/kv_cache_interface.py | 223 ++++++- vllm/v1/simple_kv_offload/manager.py | 14 +- vllm/v1/worker/gpu/attn_utils.py | 20 +- vllm/v1/worker/gpu_model_runner.py | 18 +- .../worker/kv_connector_model_runner_mixin.py | 14 +- 23 files changed, 3192 insertions(+), 167 deletions(-) create mode 100644 tests/v1/core/test_block_pool.py create mode 100644 tests/v1/core/test_kv_cache_coordinator.py create mode 100644 tests/v1/core/test_kv_cache_invariants.py diff --git a/tests/v1/core/test_block_pool.py b/tests/v1/core/test_block_pool.py new file mode 100644 index 000000000000..2521cb2d0a7b --- /dev/null +++ b/tests/v1/core/test_block_pool.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.v1.core.block_pool import BlockPool, CompactBlockPool + +pytestmark = pytest.mark.cpu_test + + +def test_block_pool_reserves_zero_as_null_block(): + block_pool = BlockPool( + num_gpu_blocks=4, + enable_caching=False, + hash_block_size=16, + ) + + assert block_pool.null_block.block_id == 0 + assert block_pool.null_block.is_null + assert block_pool.get_num_free_blocks() == 3 + assert block_pool.get_usage() == 0 + + +def test_block_pool_allocation_never_returns_null_block(): + block_pool = BlockPool( + num_gpu_blocks=4, + enable_caching=False, + hash_block_size=16, + ) + + blocks = block_pool.get_new_blocks(3) + + assert {block.block_id for block in blocks} == {1, 2, 3} + assert all(not block.is_null for block in blocks) + assert block_pool.get_num_free_blocks() == 0 + assert block_pool.get_usage() == 1 + + +def test_block_pool_exhaustion_raises_without_allocating_null_block(): + block_pool = BlockPool( + num_gpu_blocks=2, + enable_caching=False, + hash_block_size=16, + ) + + with pytest.raises(ValueError, match="Cannot get 2 free blocks"): + block_pool.get_new_blocks(2) + + block = block_pool.get_new_blocks(1)[0] + assert block.block_id == 1 + assert block_pool.null_block.block_id == 0 + + +def test_block_pool_free_returns_blocks_but_not_null_block(): + block_pool = 
BlockPool( + num_gpu_blocks=3, + enable_caching=False, + hash_block_size=16, + ) + blocks = block_pool.get_new_blocks(2) + + block_pool.free_blocks(reversed(blocks)) + + assert block_pool.get_num_free_blocks() == 2 + assert block_pool.get_usage() == 0 + reallocated = block_pool.get_new_blocks(2) + assert {block.block_id for block in reallocated} == {1, 2} + + +@pytest.mark.parametrize( + "pool", + [ + BlockPool(num_gpu_blocks=4, enable_caching=False, hash_block_size=16), + CompactBlockPool(num_allocatable=3), + ], +) +def test_block_pool_protocol_conformance(pool): + assert pool.num_gpu_blocks == 4 + assert pool.null_block.block_id == 0 + assert pool.null_block.is_null + + blocks = pool.get_new_blocks(1) + + assert len(blocks) == 1 + assert blocks[0].block_id != 0 + assert not blocks[0].is_null + pool.free_blocks(blocks) + + +def test_compact_block_pool_reserves_zero_as_null_block(): + block_pool = CompactBlockPool(num_allocatable=3) + + assert block_pool.num_gpu_blocks == 4 + assert block_pool.null_block.block_id == 0 + assert block_pool.null_block.is_null + assert block_pool.get_num_free_blocks() == 3 + assert block_pool.get_usage() == 0 + + +def test_compact_block_pool_allocation_never_returns_null_block(): + block_pool = CompactBlockPool(num_allocatable=3) + + blocks = block_pool.get_new_blocks(3) + + assert {block.block_id for block in blocks} == {1, 2, 3} + assert all(not block.is_null for block in blocks) + assert all(block.ref_cnt == 1 for block in blocks) + assert block_pool.get_num_free_blocks() == 0 + assert block_pool.get_usage() == 1 + + +def test_compact_block_pool_zero_arg_operations_are_noops(): + block_pool = CompactBlockPool(num_allocatable=1) + + assert block_pool.get_new_blocks(0) == [] + block_pool.free_blocks([]) + + assert block_pool.get_num_free_blocks() == 1 + assert block_pool.get_usage() == 0 + + +def test_compact_block_pool_exhaustion_matches_shared_pool_error_type(): + block_pool = CompactBlockPool(num_allocatable=1) + + with pytest.raises(ValueError, match="Cannot get 2 free blocks"): + block_pool.get_new_blocks(2) + + block = block_pool.get_new_blocks(1)[0] + assert block.block_id == 1 + assert block_pool.null_block.block_id == 0 + + +def test_compact_block_pool_free_returns_blocks_to_pool(): + block_pool = CompactBlockPool(num_allocatable=2) + blocks = block_pool.get_new_blocks(2) + + block_pool.free_blocks(reversed(blocks)) + + assert all(block.ref_cnt == 0 for block in blocks) + assert block_pool.get_num_free_blocks() == 2 + assert block_pool.get_usage() == 0 + reallocated = block_pool.get_new_blocks(2) + assert {block.block_id for block in reallocated} == {1, 2} + assert all(block.ref_cnt == 1 for block in reallocated) + + +def test_compact_block_pool_rejects_freeing_null_block(): + block_pool = CompactBlockPool(num_allocatable=1) + + with pytest.raises(AssertionError, match="null block must never be freed"): + block_pool.free_blocks([block_pool.null_block]) + + +def test_compact_block_pool_rejects_freeing_unallocated_block(): + block_pool = CompactBlockPool(num_allocatable=1) + block = block_pool.get_new_blocks(1)[0] + block_pool.free_blocks([block]) + + with pytest.raises(AssertionError, match="binary ref_cnt semantics"): + block_pool.free_blocks([block]) + + +def test_compact_block_pool_allows_zero_allocatable_blocks(): + block_pool = CompactBlockPool(num_allocatable=0) + + assert block_pool.num_gpu_blocks == 1 + assert block_pool.null_block.block_id == 0 + assert block_pool.get_num_free_blocks() == 0 + assert block_pool.get_usage() == 0 + assert 
block_pool.get_new_blocks(0) == [] + with pytest.raises(ValueError, match="Cannot get 1 free blocks"): + block_pool.get_new_blocks(1) diff --git a/tests/v1/core/test_kv_cache_coordinator.py b/tests/v1/core/test_kv_cache_coordinator.py new file mode 100644 index 000000000000..7b16abb79221 --- /dev/null +++ b/tests/v1/core/test_kv_cache_coordinator.py @@ -0,0 +1,376 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import Mock + +import pytest +import torch + +from vllm.sampling_params import SamplingParams +from vllm.v1.core.block_pool import BlockPool, CompactBlockPool +from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator +from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCachePoolConfig, + KVCacheTensor, + MambaSpec, + MemoryModel, + SlidingWindowSpec, +) +from vllm.v1.request import Request + +pytestmark = pytest.mark.cpu_test + + +def make_full_attention_spec(block_size: int = 16) -> FullAttentionSpec: + return FullAttentionSpec( + block_size=block_size, + num_kv_heads=2, + head_size=64, + dtype=torch.float32, + ) + + +def make_sliding_window_spec(block_size: int = 16) -> SlidingWindowSpec: + return SlidingWindowSpec( + block_size=block_size, + num_kv_heads=2, + head_size=64, + dtype=torch.float32, + sliding_window=128, + ) + + +def make_mamba_spec(block_size: int = 16) -> MambaSpec: + return MambaSpec( + block_size=block_size, + shapes=((1,),), + dtypes=(torch.float32,), + ) + + +def make_coordinator(config: KVCacheConfig, enable_caching: bool = False): + return get_kv_cache_coordinator( + kv_cache_config=config, + max_model_len=128, + max_num_batched_tokens=128, + use_eagle=False, + enable_caching=enable_caching, + enable_kv_cache_events=False, + dcp_world_size=1, + pcp_world_size=1, + hash_block_size=16, + ) + + +def make_kv_cache_manager_config( + block_size: int = 4, num_blocks: int = 3 +) -> KVCacheConfig: + return KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(["layer"], make_full_attention_spec(block_size)) + ], + ) + + +def make_multi_pool_kv_cache_manager_config() -> KVCacheConfig: + block_size = 4 + full_spec = make_full_attention_spec(block_size) + mamba_spec = make_mamba_spec(block_size) + return KVCacheConfig( + num_blocks=3, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(["attention_layer"], full_spec), + KVCacheGroupSpec(["mamba_layer"], mamba_spec), + ], + pool_configs=( + KVCachePoolConfig( + pool_id=0, + memory_model=MemoryModel.TOKEN_PROPORTIONAL, + group_ids=(0,), + num_blocks=3, + accounting_page_size_bytes=full_spec.accounting_page_size_bytes, + physical_page_size_bytes=full_spec.physical_page_size_bytes, + ), + KVCachePoolConfig( + pool_id=1, + memory_model=MemoryModel.REQUEST_CONSTANT, + group_ids=(1,), + num_blocks=2, + accounting_page_size_bytes=mamba_spec.accounting_page_size_bytes, + physical_page_size_bytes=mamba_spec.physical_page_size_bytes, + ), + ), + group_to_pool_id=(0, 1), + ) + + +def make_request(request_id: str = "request", num_tokens: int = 4) -> Request: + return Request( + request_id=request_id, + prompt_token_ids=[0] * num_tokens, + sampling_params=SamplingParams(max_tokens=1), + pooling_params=None, + ) + + +def test_attention_free_config_keeps_legacy_shared_pool_alias(): + coordinator = make_coordinator( + KVCacheConfig(num_blocks=1, kv_cache_tensors=[], 
kv_cache_groups=[]) + ) + + assert coordinator.block_pool is coordinator._block_pools[0] + assert coordinator._group_to_pool == () + assert coordinator.single_type_managers == () + assert coordinator.get_num_free_blocks_by_pool() == () + # The legacy coordinator still owns one shared BlockPool, but + # attention-free configs have no KV cache pool metadata and no managers. + assert ( + coordinator.get_num_blocks_to_allocate_by_pool( + request_id="request", + num_tokens=32, + new_computed_blocks=(), + num_encoder_tokens=0, + total_computed_tokens=0, + num_tokens_main_model=32, + ) + == () + ) + assert ( + coordinator.get_num_blocks_to_allocate( + request_id="request", + num_tokens=32, + new_computed_blocks=(), + num_encoder_tokens=0, + total_computed_tokens=0, + num_tokens_main_model=32, + ) + == 0 + ) + + +def test_single_group_config_maps_manager_to_legacy_pool_alias(): + config = KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[KVCacheTensor(size=100, shared_by=["layer_0"])], + kv_cache_groups=[KVCacheGroupSpec(["layer_0"], make_full_attention_spec())], + ) + + coordinator = make_coordinator(config, enable_caching=True) + + assert coordinator.block_pool is coordinator._block_pools[0] + assert coordinator._group_to_pool == (coordinator.block_pool,) + assert coordinator.single_type_managers[0].block_pool is coordinator.block_pool + assert coordinator.get_num_free_blocks_by_pool() == (9,) + assert coordinator.get_num_blocks_to_allocate_by_pool( + request_id="request", + num_tokens=32, + new_computed_blocks=((),), + num_encoder_tokens=0, + total_computed_tokens=0, + num_tokens_main_model=32, + ) == (2,) + assert ( + coordinator.get_num_blocks_to_allocate( + request_id="request", + num_tokens=32, + new_computed_blocks=((),), + num_encoder_tokens=0, + total_computed_tokens=0, + num_tokens_main_model=32, + ) + == 2 + ) + + +def test_multi_group_single_pool_config_maps_all_managers_to_legacy_pool_alias(): + config = KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=100, shared_by=["layer_0"]), + KVCacheTensor(size=100, shared_by=["layer_1"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_0"], make_full_attention_spec()), + KVCacheGroupSpec(["layer_1"], make_sliding_window_spec()), + ], + ) + + coordinator = make_coordinator(config) + + assert coordinator.block_pool is coordinator._block_pools[0] + assert coordinator._group_to_pool == ( + coordinator.block_pool, + coordinator.block_pool, + ) + assert all( + manager.block_pool is coordinator.block_pool + for manager in coordinator.single_type_managers + ) + assert coordinator.get_num_free_blocks_by_pool() == (9,) + assert coordinator.get_num_blocks_to_allocate_by_pool( + request_id="request", + num_tokens=32, + new_computed_blocks=((), ()), + num_encoder_tokens=0, + total_computed_tokens=0, + num_tokens_main_model=32, + ) == (4,) + assert ( + coordinator.get_num_blocks_to_allocate( + request_id="request", + num_tokens=32, + new_computed_blocks=((), ()), + num_encoder_tokens=0, + total_computed_tokens=0, + num_tokens_main_model=32, + ) + == 4 + ) + + +def test_multi_pool_config_without_prefix_cache_builds_configured_pools(): + full_spec = make_full_attention_spec() + mamba_spec = make_mamba_spec() + config = KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_0"], full_spec), + KVCacheGroupSpec(["layer_1"], mamba_spec), + ], + pool_configs=( + KVCachePoolConfig( + pool_id=0, + memory_model=MemoryModel.TOKEN_PROPORTIONAL, + group_ids=(0,), + num_blocks=10, + 
accounting_page_size_bytes=full_spec.accounting_page_size_bytes, + physical_page_size_bytes=full_spec.physical_page_size_bytes, + ), + KVCachePoolConfig( + pool_id=1, + memory_model=MemoryModel.REQUEST_CONSTANT, + group_ids=(1,), + num_blocks=3, + accounting_page_size_bytes=mamba_spec.accounting_page_size_bytes, + physical_page_size_bytes=mamba_spec.physical_page_size_bytes, + ), + ), + group_to_pool_id=(0, 1), + ) + + coordinator = make_coordinator(config, enable_caching=False) + + assert isinstance(coordinator._block_pools[0], BlockPool) + assert isinstance(coordinator._block_pools[1], CompactBlockPool) + assert coordinator.block_pool is coordinator._block_pools[0] + assert coordinator._group_to_pool == coordinator._block_pools + assert coordinator.get_num_free_blocks_by_pool() == (9, 2) + # Full attention remains token-proportional: 32 tokens / block_size 16 = 2. + # Request-constant Mamba uses one compact state block per request. + assert coordinator.get_num_blocks_to_allocate_by_pool( + request_id="request", + num_tokens=32, + new_computed_blocks=((), ()), + num_encoder_tokens=0, + total_computed_tokens=0, + num_tokens_main_model=32, + ) == (2, 1) + + +def test_has_enough_free_blocks_by_pool_uses_pool_tuple(): + manager = KVCacheManager( + kv_cache_config=make_kv_cache_manager_config(num_blocks=3), + max_model_len=16, + hash_block_size=4, + enable_caching=False, + ) + + # num_blocks=3 includes the null sentinel, leaving two allocatable blocks. + assert manager.coordinator.get_num_free_blocks_by_pool() == (2,) + assert manager._has_enough_free_blocks_by_pool((2,)) + assert not manager._has_enough_free_blocks_by_pool((3,)) + + +def test_allocate_slots_uses_pool_aware_accounting_path(): + manager = KVCacheManager( + kv_cache_config=make_kv_cache_manager_config(num_blocks=3), + max_model_len=16, + hash_block_size=4, + enable_caching=False, + ) + request = make_request(num_tokens=4) + manager.coordinator.get_num_blocks_to_allocate = Mock( + side_effect=AssertionError("legacy scalar accounting should not be used") + ) + get_num_blocks_to_allocate_by_pool = Mock( + wraps=manager.coordinator.get_num_blocks_to_allocate_by_pool + ) + manager.coordinator.get_num_blocks_to_allocate_by_pool = ( + get_num_blocks_to_allocate_by_pool + ) + + blocks = manager.allocate_slots(request, num_new_tokens=4) + + assert blocks is not None + assert blocks.get_block_ids() == ([1],) + manager.coordinator.get_num_blocks_to_allocate.assert_not_called() + get_num_blocks_to_allocate_by_pool.assert_called_once() + + +def test_allocate_slots_checks_each_pool_independently(): + manager = KVCacheManager( + kv_cache_config=make_multi_pool_kv_cache_manager_config(), + max_model_len=16, + hash_block_size=4, + enable_caching=False, + ) + + blocks = manager.allocate_slots(make_request("first", num_tokens=4), 4) + + assert blocks is not None + assert blocks.get_block_ids() == ([1], [1]) + assert manager.coordinator.get_num_free_blocks_by_pool() == (1, 0) + assert manager.allocate_slots(make_request("second", num_tokens=4), 4) is None + + +def test_multi_pool_config_with_prefix_cache_is_rejected_until_pool_aware(): + full_spec = make_full_attention_spec() + sliding_spec = make_sliding_window_spec() + config = KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_0"], full_spec), + KVCacheGroupSpec(["layer_1"], sliding_spec), + ], + pool_configs=( + KVCachePoolConfig( + pool_id=0, + memory_model=MemoryModel.TOKEN_PROPORTIONAL, + group_ids=(0,), + num_blocks=10, + 
accounting_page_size_bytes=full_spec.accounting_page_size_bytes, + physical_page_size_bytes=full_spec.physical_page_size_bytes, + ), + KVCachePoolConfig( + pool_id=1, + memory_model=MemoryModel.TOKEN_PROPORTIONAL, + group_ids=(1,), + num_blocks=10, + accounting_page_size_bytes=sliding_spec.accounting_page_size_bytes, + physical_page_size_bytes=sliding_spec.physical_page_size_bytes, + ), + ), + group_to_pool_id=(0, 1), + ) + + with pytest.raises(NotImplementedError, match="prefix caching is disabled"): + make_coordinator(config, enable_caching=True) diff --git a/tests/v1/core/test_kv_cache_invariants.py b/tests/v1/core/test_kv_cache_invariants.py new file mode 100644 index 000000000000..72b6df5eba92 --- /dev/null +++ b/tests/v1/core/test_kv_cache_invariants.py @@ -0,0 +1,188 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable + +import pytest +import torch + +from vllm.utils.torch_utils import get_dtype_size +from vllm.v1.kv_cache_interface import ( + ChunkedLocalAttentionSpec, + CrossAttentionSpec, + EncoderOnlyAttentionSpec, + FullAttentionSpec, + KVCacheSpec, + MambaSpec, + MemoryModel, + MLAAttentionSpec, + SinkFullAttentionSpec, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) + +pytestmark = pytest.mark.cpu_test + + +def _full_attention_spec() -> FullAttentionSpec: + return FullAttentionSpec( + block_size=16, + num_kv_heads=2, + head_size=64, + dtype=torch.float16, + ) + + +def _mla_attention_spec() -> MLAAttentionSpec: + return MLAAttentionSpec( + block_size=16, + num_kv_heads=1, + head_size=64, + dtype=torch.float16, + ) + + +def _chunked_local_attention_spec() -> ChunkedLocalAttentionSpec: + return ChunkedLocalAttentionSpec( + block_size=16, + num_kv_heads=2, + head_size=64, + dtype=torch.float16, + attention_chunk_size=128, + ) + + +def _sliding_window_spec() -> SlidingWindowSpec: + return SlidingWindowSpec( + block_size=16, + num_kv_heads=2, + head_size=64, + dtype=torch.float16, + sliding_window=128, + ) + + +def _mamba_spec() -> MambaSpec: + return MambaSpec( + block_size=16, + shapes=((4, 8), (2, 8)), + dtypes=(torch.float16, torch.float32), + ) + + +def _encoder_only_attention_spec() -> EncoderOnlyAttentionSpec: + return EncoderOnlyAttentionSpec( + block_size=16, + num_kv_heads=2, + head_size=64, + dtype=torch.float16, + ) + + +def _cross_attention_spec() -> CrossAttentionSpec: + return CrossAttentionSpec( + block_size=16, + num_kv_heads=2, + head_size=64, + dtype=torch.float16, + ) + + +def _sink_full_attention_spec() -> SinkFullAttentionSpec: + return SinkFullAttentionSpec( + block_size=16, + num_kv_heads=2, + head_size=64, + dtype=torch.float16, + sink_len=4, + ) + + +def _uniform_type_kv_cache_specs() -> UniformTypeKVCacheSpecs: + return UniformTypeKVCacheSpecs( + block_size=16, + kv_cache_specs={ + "layer.0": _full_attention_spec(), + "layer.1": _full_attention_spec(), + }, + ) + + +TOKEN_PROPORTIONAL_SPEC_FACTORIES: list[tuple[str, Callable[[], KVCacheSpec]]] = [ + ("FullAttentionSpec", _full_attention_spec), + ("MLAAttentionSpec", _mla_attention_spec), + ("ChunkedLocalAttentionSpec", _chunked_local_attention_spec), + ("SlidingWindowSpec", _sliding_window_spec), + ("EncoderOnlyAttentionSpec", _encoder_only_attention_spec), + ("CrossAttentionSpec", _cross_attention_spec), + ("SinkFullAttentionSpec", _sink_full_attention_spec), + ("UniformTypeKVCacheSpecs", _uniform_type_kv_cache_specs), +] + + +@pytest.mark.parametrize( + ("spec_name", "make_spec"), + 
TOKEN_PROPORTIONAL_SPEC_FACTORIES, + ids=[name for name, _ in TOKEN_PROPORTIONAL_SPEC_FACTORIES], +) +def test_default_memory_model_is_token_proportional( + spec_name: str, make_spec: Callable[[], KVCacheSpec] +) -> None: + spec = make_spec() + + assert spec.memory_model == MemoryModel.TOKEN_PROPORTIONAL, spec_name + assert spec.accounting_page_size_bytes == spec.page_size_bytes, spec_name + assert spec.requires_block_zeroing_on_alloc is True, spec_name + + +@pytest.mark.parametrize("mamba_cache_mode", ["none", "align", "all"]) +def test_mamba_zeroing_metadata_matches_current_zeroer( + mamba_cache_mode: str, +) -> None: + spec = MambaSpec( + block_size=16, + shapes=((4, 8), (2, 8)), + dtypes=(torch.float16, torch.float32), + mamba_cache_mode=mamba_cache_mode, + ) + + expected_memory_model = ( + MemoryModel.TOKEN_PROPORTIONAL + if mamba_cache_mode == "all" + else MemoryModel.REQUEST_CONSTANT + ) + assert spec.memory_model == expected_memory_model + assert spec.accounting_page_size_bytes == spec.page_size_bytes + assert spec.requires_block_zeroing_on_alloc is False + + +def test_mamba_physical_page_size_excludes_accounting_padding() -> None: + spec = MambaSpec( + block_size=16, + shapes=((4, 8), (2, 8)), + dtypes=(torch.float16, torch.float32), + page_size_padded=1024, + ) + expected_physical = (4 * 8 * get_dtype_size(torch.float16)) + ( + 2 * 8 * get_dtype_size(torch.float32) + ) + + assert spec.physical_page_size_bytes == expected_physical + assert spec.page_size_bytes == 1024 + assert spec.accounting_page_size_bytes == 1024 + + +def test_uniform_type_physical_page_size_sums_children() -> None: + full_spec = _full_attention_spec() + mla_spec = _mla_attention_spec() + uniform_spec = UniformTypeKVCacheSpecs( + block_size=16, + kv_cache_specs={ + "full": full_spec, + "mla": mla_spec, + }, + ) + + assert uniform_spec.physical_page_size_bytes == ( + full_spec.physical_page_size_bytes + mla_spec.physical_page_size_bytes + ) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 985b97c69ca4..5cd4ec96d527 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -3,6 +3,7 @@ import hashlib import importlib from collections.abc import Callable +from dataclasses import dataclass from typing import Any import pytest @@ -10,6 +11,7 @@ import vllm.v1.core.kv_cache_utils as kv_cache_utils from vllm.config import ModelConfig, SchedulerConfig, VllmConfig +from vllm.config.compilation import CompilationConfig, CUDAGraphMode from vllm.config.kv_events import KVEventsConfig from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import ( @@ -20,6 +22,7 @@ from vllm.sampling_params import SamplingParams from vllm.utils.hashing import sha256, sha256_cbor from vllm.utils.mem_constants import GiB_bytes +from vllm.v1.attention.backend import AttentionCGSupport from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_utils import ( BlockHash, @@ -31,6 +34,7 @@ get_kv_cache_configs, get_max_concurrency_for_kv_cache_config, get_request_block_hasher, + get_token_proportional_kv_cache_capacity_tokens, hash_block_tokens, init_none_hash, is_kv_cache_spec_uniform, @@ -42,9 +46,11 @@ FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, + KVCachePoolConfig, KVCacheSpec, KVCacheTensor, MambaSpec, + MemoryModel, MLAAttentionSpec, SlidingWindowSpec, UniformTypeKVCacheSpecs, @@ -177,6 +183,67 @@ def new_mamba_spec( ) +@dataclass(frozen=True) +class _DummyRequestConstantSpec(MambaSpec): + """Test-only spec for 
generated request-constant config paths.""" + + @property + def memory_model(self) -> MemoryModel: + return MemoryModel.REQUEST_CONSTANT + + @property + def blocks_per_request(self) -> int: + return 1 + self.num_speculative_blocks + + +def new_request_constant_spec( + block_size=16, + shapes=((2,),), + dtypes=(torch.float32,), + num_speculative_blocks=0, + page_size_padded=None, +): + return _DummyRequestConstantSpec( + block_size=block_size, + shapes=shapes, + dtypes=dtypes, + num_speculative_blocks=num_speculative_blocks, + page_size_padded=page_size_padded, + ) + + +def assert_legacy_single_pool_metadata(config: KVCacheConfig) -> None: + if len(config.kv_cache_groups) == 0: + assert config.pool_configs == () + assert config.group_to_pool_id == () + assert config.num_blocks == 1 + return + + assert len(config.pool_configs) == 1 + pool_config = config.pool_configs[0] + assert pool_config.pool_id == 0 + assert pool_config.memory_model == MemoryModel.TOKEN_PROPORTIONAL + assert pool_config.group_ids == tuple(range(len(config.kv_cache_groups))) + assert config.group_to_pool_id == tuple(0 for _ in config.kv_cache_groups) + assert pool_config.num_blocks == config.num_blocks + assert config.num_blocks == sum(pool.num_blocks for pool in config.pool_configs) + + accounting_page_sizes = { + group.kv_cache_spec.accounting_page_size_bytes + for group in config.kv_cache_groups + } + physical_page_sizes = { + group.kv_cache_spec.physical_page_size_bytes for group in config.kv_cache_groups + } + assert len(accounting_page_sizes) == 1 + accounting_page_size = accounting_page_sizes.pop() + assert pool_config.accounting_page_size_bytes == accounting_page_size + if len(physical_page_sizes) == 1: + assert pool_config.physical_page_size_bytes == physical_page_sizes.pop() + else: + assert pool_config.physical_page_size_bytes == accounting_page_size + + @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_none_hash(monkeypatch, hash_fn): import vllm.v1.core.kv_cache_utils @@ -1742,16 +1809,35 @@ def test_get_kv_cache_config_one_worker(): ], ) - # different hidden size that cannot be aligned by using different block size + # Different hidden size and different type that cannot be aligned by using + # different block size. This is supported by padding the smaller page. 
kv_cache_specs_hybrid = { "layer_1": new_kv_cache_spec(head_size=64), "layer_2": new_sliding_window_spec(head_size=96), } - with pytest.raises(NotImplementedError): - get_kv_cache_configs( - vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] - )[0] + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] + )[0] + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=42, + kv_cache_tensors=[ + KVCacheTensor( + size=mem_per_block_per_layer * 3 // 2 * 42, + shared_by=["layer_1", "layer_2"], + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_1"], + new_kv_cache_spec( + head_size=64, + page_size_padded=mem_per_block_per_layer * 3 // 2, + ), + ), + KVCacheGroupSpec(["layer_2"], new_sliding_window_spec(head_size=96)), + ], + ) # Test num_gpu_blocks_override vllm_config.cache_config.num_gpu_blocks_override = 16 @@ -1779,6 +1865,435 @@ def test_get_kv_cache_configs_attention_free(): kv_cache_groups=[], ) ] + assert_legacy_single_pool_metadata(kv_cache_configs[0]) + + +def test_kv_cache_config_legacy_pool_metadata_single_group(): + spec = new_kv_cache_spec() + config = KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=spec.page_size_bytes * 10, shared_by=["layer_1"]), + KVCacheTensor(size=spec.page_size_bytes * 10, shared_by=["layer_2"]), + ], + kv_cache_groups=[KVCacheGroupSpec(["layer_1", "layer_2"], spec)], + ) + + assert_legacy_single_pool_metadata(config) + + +def test_kv_cache_config_legacy_pool_metadata_multi_group(): + model_config = ModelConfig(max_model_len=16) + vllm_config = VllmConfig(model_config=model_config) + mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2 + kv_cache_specs = { + "layer_1": new_kv_cache_spec(), + "layer_2": new_sliding_window_spec(), + } + + config = get_kv_cache_configs( + vllm_config, [kv_cache_specs], [mem_per_block_per_layer * 2 * 32] + )[0] + + assert len(config.kv_cache_groups) == 2 + assert_legacy_single_pool_metadata(config) + + +def test_kv_cache_config_legacy_pool_metadata_mixed_physical_page_sizes(): + unpadded_mamba_spec = new_mamba_spec(mamba_cache_mode="all") + unified_page_size = unpadded_mamba_spec.physical_page_size_bytes + 1024 + attention_spec = new_kv_cache_spec(page_size_padded=unified_page_size) + mamba_spec = new_mamba_spec( + page_size_padded=unified_page_size, + mamba_cache_mode="all", + ) + assert ( + attention_spec.physical_page_size_bytes != mamba_spec.physical_page_size_bytes + ) + + config = KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1"], attention_spec), + KVCacheGroupSpec(["layer_2"], mamba_spec), + ], + ) + + assert_legacy_single_pool_metadata(config) + + +def test_kv_cache_config_pool_metadata_tracks_worker_min_blocks(): + model_config = ModelConfig(max_model_len=16) + vllm_config = VllmConfig(model_config=model_config) + spec = new_kv_cache_spec() + kv_cache_specs = [ + { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, + { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, + ] + + kv_cache_configs = get_kv_cache_configs( + vllm_config, + kv_cache_specs, + [ + spec.page_size_bytes * 2 * 10, + spec.page_size_bytes * 2 * 20, + ], + ) + + for config in kv_cache_configs: + assert config.num_blocks == 10 + assert_legacy_single_pool_metadata(config) + + +def make_request_constant_vllm_config( + max_model_len: int = 16, + max_num_seqs: int = 4, +) -> VllmConfig: + model_config = 
ModelConfig(max_model_len=max_model_len) + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ) + vllm_config = VllmConfig( + model_config=model_config, + scheduler_config=scheduler_config, + ) + vllm_config.cache_config.enable_prefix_caching = False + return vllm_config + + +def test_real_mamba_spec_none_mode_is_request_constant(): + spec = new_mamba_spec( + mamba_cache_mode="none", + num_speculative_blocks=2, + ) + + assert spec.memory_model == MemoryModel.REQUEST_CONSTANT + assert spec.blocks_per_request == 3 + assert spec.physical_page_size_bytes == spec.page_size_bytes + + +def test_real_mamba_spec_align_mode_blocks_per_request(): + spec = new_mamba_spec( + mamba_cache_mode="align", + num_speculative_blocks=2, + ) + + assert spec.memory_model == MemoryModel.REQUEST_CONSTANT + assert spec.blocks_per_request == 4 + + +def test_real_mamba_spec_all_mode_is_token_proportional(): + spec = new_mamba_spec( + mamba_cache_mode="all", + num_speculative_blocks=2, + ) + + assert spec.memory_model == MemoryModel.TOKEN_PROPORTIONAL + assert spec.blocks_per_request == 3 + + +def test_hybrid_qwen_like_config_generates_multi_pool(): + vllm_config = make_request_constant_vllm_config(max_num_seqs=4) + attention_spec = new_kv_cache_spec( + block_size=4, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ) + mamba_spec = new_mamba_spec( + block_size=4, + shapes=((4,),), + dtypes=(torch.float32,), + num_speculative_blocks=0, + mamba_cache_mode="none", + page_size_padded=attention_spec.page_size_bytes, + ) + mamba_num_blocks = 4 * mamba_spec.blocks_per_request + 1 + mamba_reserved_bytes = mamba_num_blocks * mamba_spec.physical_page_size_bytes + available_memory = mamba_reserved_bytes + attention_spec.page_size_bytes * 20 + + config = get_kv_cache_configs( + vllm_config, + [{"attn": attention_spec, "mamba": mamba_spec}], + [available_memory], + )[0] + + assert config.kv_cache_groups == [ + KVCacheGroupSpec(["attn"], attention_spec), + KVCacheGroupSpec(["mamba"], mamba_spec), + ] + assert config.group_to_pool_id == (0, 1) + assert config.pool_configs[0].memory_model == MemoryModel.TOKEN_PROPORTIONAL + assert config.pool_configs[0].num_blocks == 20 + assert config.pool_configs[1].memory_model == MemoryModel.REQUEST_CONSTANT + assert config.pool_configs[1].num_blocks == mamba_num_blocks + assert config.kv_cache_tensors == [ + KVCacheTensor(size=attention_spec.page_size_bytes * 20, shared_by=["attn"]), + KVCacheTensor(size=mamba_reserved_bytes, shared_by=["mamba"]), + ] + + +@pytest.mark.parametrize("enable_prefix_caching", [True, False]) +def test_real_mamba_spec_all_mode_keeps_shared_pool(enable_prefix_caching): + vllm_config = make_request_constant_vllm_config(max_num_seqs=4) + vllm_config.cache_config.enable_prefix_caching = enable_prefix_caching + attention_spec = new_kv_cache_spec( + block_size=4, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + page_size_padded=64, + ) + mamba_spec = new_mamba_spec( + block_size=4, + shapes=((4,),), + dtypes=(torch.float32,), + mamba_cache_mode="all", + page_size_padded=64, + ) + + config = get_kv_cache_configs( + vllm_config, + [{"attn": attention_spec, "mamba": mamba_spec}], + [64 * 10], + )[0] + + assert config.group_to_pool_id == (0, 0) + assert len(config.pool_configs) == 1 + assert config.pool_configs[0].memory_model == MemoryModel.TOKEN_PROPORTIONAL + + +def test_token_proportional_capacity_ignores_request_constant_pool(): + vllm_config = 
make_request_constant_vllm_config( + max_model_len=16, + max_num_seqs=4, + ) + attention_spec = new_kv_cache_spec( + block_size=4, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ) + mamba_spec = new_mamba_spec( + block_size=4, + shapes=((4,),), + dtypes=(torch.float32,), + mamba_cache_mode="none", + page_size_padded=attention_spec.page_size_bytes, + ) + mamba_num_blocks = 4 * mamba_spec.blocks_per_request + 1 + mamba_reserved_bytes = mamba_num_blocks * mamba_spec.physical_page_size_bytes + + config = get_kv_cache_configs( + vllm_config, + [{"attn": attention_spec, "mamba": mamba_spec}], + [mamba_reserved_bytes + attention_spec.page_size_bytes * 16], + )[0] + + assert get_token_proportional_kv_cache_capacity_tokens(config) == 64 + assert get_max_concurrency_for_kv_cache_config(vllm_config, config) == 4 + + +def test_request_constant_mamba_full_cudagraph_fails_closed(): + vllm_config = make_request_constant_vllm_config(max_num_seqs=4) + attention_spec = new_kv_cache_spec( + block_size=4, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ) + mamba_spec = new_mamba_spec( + block_size=4, + shapes=((4,),), + dtypes=(torch.float32,), + mamba_cache_mode="none", + page_size_padded=attention_spec.page_size_bytes, + ) + mamba_num_blocks = 4 * mamba_spec.blocks_per_request + 1 + mamba_reserved_bytes = mamba_num_blocks * mamba_spec.physical_page_size_bytes + kv_cache_config = get_kv_cache_configs( + vllm_config, + [{"attn": attention_spec, "mamba": mamba_spec}], + [mamba_reserved_bytes + attention_spec.page_size_bytes * 16], + )[0] + compilation_config = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL) + + with pytest.raises( + ValueError, + match="Full CUDA graph capture with REQUEST_CONSTANT KV cache", + ): + compilation_config.resolve_cudagraph_mode_and_sizes( + min_cg_support=AttentionCGSupport.ALWAYS, + min_cg_attn_backend="test", + kv_cache_config=kv_cache_config, + max_num_reqs=4, + ) + + +def test_mixed_memory_model_config_reserves_request_constant_pool(): + vllm_config = make_request_constant_vllm_config(max_num_seqs=4) + attention_spec = new_kv_cache_spec( + block_size=4, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ) + request_constant_spec = new_request_constant_spec( + num_speculative_blocks=1, + page_size_padded=16, + ) + request_constant_num_blocks = 4 * request_constant_spec.blocks_per_request + 1 + reserved_bytes = ( + request_constant_num_blocks * request_constant_spec.physical_page_size_bytes + ) + available_memory = reserved_bytes + attention_spec.page_size_bytes * 10 + + config = get_kv_cache_configs( + vllm_config, + [{"attn": attention_spec, "state": request_constant_spec}], + [available_memory], + )[0] + + assert config.num_blocks == 10 + assert config.kv_cache_groups == [ + KVCacheGroupSpec(["attn"], attention_spec), + KVCacheGroupSpec(["state"], request_constant_spec), + ] + assert config.kv_cache_tensors == [ + KVCacheTensor(size=attention_spec.page_size_bytes * 10, shared_by=["attn"]), + KVCacheTensor(size=reserved_bytes, shared_by=["state"]), + ] + assert config.group_to_pool_id == (0, 1) + assert config.pool_configs == ( + KVCachePoolConfig( + pool_id=0, + memory_model=MemoryModel.TOKEN_PROPORTIONAL, + group_ids=(0,), + num_blocks=10, + accounting_page_size_bytes=attention_spec.accounting_page_size_bytes, + physical_page_size_bytes=attention_spec.physical_page_size_bytes, + ), + KVCachePoolConfig( + pool_id=1, + memory_model=MemoryModel.REQUEST_CONSTANT, + group_ids=(1,), + num_blocks=request_constant_num_blocks, + 
accounting_page_size_bytes=( + request_constant_spec.accounting_page_size_bytes + ), + physical_page_size_bytes=request_constant_spec.physical_page_size_bytes, + ), + ) + + +def test_multi_pool_config_deterministic_across_workers(): + vllm_config = make_request_constant_vllm_config(max_num_seqs=4) + attention_spec = new_kv_cache_spec( + block_size=4, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ) + request_constant_spec = new_request_constant_spec( + num_speculative_blocks=1, + page_size_padded=16, + ) + request_constant_num_blocks = 4 * request_constant_spec.blocks_per_request + 1 + reserved_bytes = ( + request_constant_num_blocks * request_constant_spec.physical_page_size_bytes + ) + + kv_cache_configs = get_kv_cache_configs( + vllm_config, + [ + {"attn": attention_spec, "state": request_constant_spec}, + {"attn": attention_spec, "state": request_constant_spec}, + ], + [ + reserved_bytes + attention_spec.page_size_bytes * 10, + reserved_bytes + attention_spec.page_size_bytes * 20, + ], + ) + + for config in kv_cache_configs: + assert config.num_blocks == 10 + assert config.pool_configs[0].num_blocks == 10 + assert config.pool_configs[1].num_blocks == request_constant_num_blocks + assert config.kv_cache_tensors == [ + KVCacheTensor(size=attention_spec.page_size_bytes * 10, shared_by=["attn"]), + KVCacheTensor(size=reserved_bytes, shared_by=["state"]), + ] + + scheduler_config = generate_scheduler_kv_cache_config(kv_cache_configs) + assert scheduler_config.pool_configs == kv_cache_configs[0].pool_configs + assert scheduler_config.group_to_pool_id == (0, 1) + + +def test_request_constant_prefix_caching_fails_early(): + vllm_config = make_request_constant_vllm_config() + vllm_config.cache_config.enable_prefix_caching = True + attention_spec = new_kv_cache_spec( + block_size=4, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ) + request_constant_spec = new_request_constant_spec() + reserved_bytes = ( + vllm_config.scheduler_config.max_num_seqs + * request_constant_spec.blocks_per_request + + 1 + ) * request_constant_spec.physical_page_size_bytes + + with pytest.raises( + NotImplementedError, + match="Prefix caching with REQUEST_CONSTANT groups", + ): + get_kv_cache_configs( + vllm_config, + [{"attn": attention_spec, "state": request_constant_spec}], + [reserved_bytes + attention_spec.page_size_bytes * 10], + ) + + +def test_request_constant_reservation_fails_closed_when_memory_exhausted(): + vllm_config = make_request_constant_vllm_config(max_num_seqs=4) + attention_spec = new_kv_cache_spec( + block_size=4, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ) + request_constant_spec = new_request_constant_spec( + num_speculative_blocks=1, + page_size_padded=16, + ) + request_constant_num_blocks = 4 * request_constant_spec.blocks_per_request + 1 + reserved_bytes = ( + request_constant_num_blocks * request_constant_spec.physical_page_size_bytes + ) + + with pytest.raises( + ValueError, + match="REQUEST_CONSTANT KV cache reservation", + ): + kv_cache_utils.get_kv_cache_config_from_groups( + vllm_config, + [ + KVCacheGroupSpec(["attn"], attention_spec), + KVCacheGroupSpec(["state"], request_constant_spec), + ], + available_memory=reserved_bytes, + ) def test_generate_uniform_type_kv_cache_specs(): @@ -1854,6 +2369,36 @@ def test_generate_scheduler_kv_cache_config(): ) +def test_generate_scheduler_kv_cache_config_rejects_mismatched_pool_schema(): + spec = new_kv_cache_spec() + kv_cache_config = KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[], + 
kv_cache_groups=[KVCacheGroupSpec(["layer_1"], spec)], + ) + mismatched_kv_cache_config = KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[], + kv_cache_groups=[KVCacheGroupSpec(["layer_1"], spec)], + pool_configs=( + KVCachePoolConfig( + pool_id=0, + memory_model=MemoryModel.TOKEN_PROPORTIONAL, + group_ids=(0,), + num_blocks=10, + accounting_page_size_bytes=spec.accounting_page_size_bytes + 1, + physical_page_size_bytes=spec.physical_page_size_bytes, + ), + ), + group_to_pool_id=(0,), + ) + + with pytest.raises(AssertionError): + generate_scheduler_kv_cache_config( + [kv_cache_config, mismatched_kv_cache_config] + ) + + def new_mla_spec(cache_dtype_str=None): # head_size = kv_lora_rank(512) + qk_rope_head_dim(64) = 576 return MLAAttentionSpec( @@ -2039,7 +2584,16 @@ def test_auto_fit_max_model_len_with_hybrid(): model_config = ModelConfig(max_model_len=8192) # Simulate the user passing -1 by setting original_max_model_len model_config.original_max_model_len = -1 - vllm_config = VllmConfig(model_config=model_config) + scheduler_config = SchedulerConfig( + max_num_seqs=1, + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ) + vllm_config = VllmConfig( + model_config=model_config, + scheduler_config=scheduler_config, + ) + vllm_config.cache_config.enable_prefix_caching = False mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2 # 16KB per block per layer gamma = 2 @@ -2048,7 +2602,9 @@ def test_auto_fit_max_model_len_with_hybrid(): "layer_2": new_kv_cache_spec(), } - available_memory = mem_per_block_per_layer * (1024 // 16 + 1 + gamma) + # 64 attention blocks for 1024 tokens plus 3 compact Mamba blocks for the + # single request and 1 compact-pool null block. + available_memory = mem_per_block_per_layer * (1024 // 16 + 1 + gamma + 1) _kv_cache_configs = get_kv_cache_configs( vllm_config, [kv_cache_specs], [available_memory] ) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index c35c38911a1a..b21093dee721 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -35,7 +35,9 @@ FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, + KVCachePoolConfig, MambaSpec, + MemoryModel, SlidingWindowSpec, ) @@ -125,8 +127,9 @@ def make_kv_cache_config_hybrid_model( elif second_spec_type == "mamba": second_spec = MambaSpec( block_size=block_size, - shapes=(1, 1), + shapes=((1, 1),), dtypes=(torch.float32,), + page_size_padded=8 * block_size, ) return KVCacheConfig( @@ -160,8 +163,9 @@ def make_kv_cache_config_three_types( if third_spec_type == "mamba": third_spec = MambaSpec( block_size=block_size, - shapes=(1, 1), + shapes=((1, 1),), dtypes=(torch.float32,), + page_size_padded=8 * block_size, ) elif third_spec_type == "sliding_window": third_spec = SlidingWindowSpec( @@ -740,13 +744,19 @@ def _make_hybrid_kv_cache_config( ), "mamba": lambda: MambaSpec( block_size=block_size, - shapes=(1, 1), + shapes=((1, 1),), dtypes=(torch.float32,), + page_size_padded=8 * block_size, + # Prefix-caching tests exercise the legacy shared-pool Mamba path. + # Non-"all" Mamba modes are REQUEST_CONSTANT and intentionally + # fail-closed with prefix caching. 
+ mamba_cache_mode="all", ), "mamba_align": lambda: MambaSpec( block_size=block_size, - shapes=(1, 1), + shapes=((1, 1),), dtypes=(torch.float32,), + page_size_padded=8 * block_size, mamba_cache_mode="align", ), } @@ -967,44 +977,61 @@ def test_prefill_hybrid_model_combinations_eagle( manager.free(req1) -def test_prefill_hybrid_model_mamba_align(): - """Test that MambaManager.cache_blocks() handles null blocks in align mode. - - Regression test for https://github.com/vllm-project/vllm/issues/34361. - In mamba_cache_mode="align", allocate_new_blocks() pads req_to_blocks with - null blocks. cache_full_blocks() correctly skips them, but - MambaManager.cache_blocks() must also skip null blocks when tracking - cached_blocks_this_step. - """ +def test_prefill_hybrid_model_mamba_align_prefix_caching_rejected(): + """mamba_cache_mode="align" is REQUEST_CONSTANT and not prefix-cacheable.""" block_size = 16 num_blocks = 30 - - kv_cache_config = _make_hybrid_kv_cache_config( - block_size, num_blocks, ["full", "mamba_align"] + full_spec = FullAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, ) - manager = KVCacheManager( - kv_cache_config, - max_model_len=8192, - enable_caching=True, - hash_block_size=block_size, + mamba_spec = MambaSpec( + block_size=block_size, + shapes=((1, 1),), + dtypes=(torch.float32,), + page_size_padded=8 * block_size, + mamba_cache_mode="align", ) - - hash_fn = sha256 - - # 3 full blocks (48 tokens) + 7 partial tokens = 55 tokens total - all_token_ids = [i for i in range(3) for _ in range(block_size)] + [3] * 7 - - # First request: allocate_slots should not crash with the assertion error - # in MambaManager.cache_blocks() when null blocks are present. - req0 = make_request("0", all_token_ids, block_size, hash_fn) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert num_computed_tokens == 0 - - blocks = manager.allocate_slots(req0, 55, num_computed_tokens, computed_blocks) - assert blocks is not None - assert len(blocks.get_block_ids()) == 2 # full_attn + mamba groups - - manager.free(req0) + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(["layer0"], full_spec), + KVCacheGroupSpec(["layer1"], mamba_spec), + ], + pool_configs=( + KVCachePoolConfig( + pool_id=0, + memory_model=MemoryModel.TOKEN_PROPORTIONAL, + group_ids=(0,), + num_blocks=num_blocks, + accounting_page_size_bytes=full_spec.accounting_page_size_bytes, + physical_page_size_bytes=full_spec.physical_page_size_bytes, + ), + KVCachePoolConfig( + pool_id=1, + memory_model=MemoryModel.REQUEST_CONSTANT, + group_ids=(1,), + num_blocks=3, + accounting_page_size_bytes=mamba_spec.accounting_page_size_bytes, + physical_page_size_bytes=mamba_spec.physical_page_size_bytes, + ), + ), + group_to_pool_id=(0, 1), + ) + + with pytest.raises( + NotImplementedError, + match="multi-pool configs only when prefix caching is disabled", + ): + KVCacheManager( + kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + ) def test_prefill_plp(): diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 42f4825e2b3b..877032cfa530 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -31,6 +31,9 @@ FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, + KVCachePoolConfig, + MambaSpec, + MemoryModel, ) from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.request 
import Request, RequestStatus @@ -74,6 +77,96 @@ def test_get_num_unfinished_requests(): assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1 +def test_routed_experts_capacity_uses_token_proportional_pool(monkeypatch): + block_size = 16 + num_attention_blocks = 8 + model_config = ModelConfig( + model="facebook/opt-125m", + trust_remote_code=True, + dtype="float16", + seed=42, + skip_tokenizer_init=True, + ) + model_config.enable_return_routed_experts = True + scheduler_config = SchedulerConfig( + max_num_seqs=1, + max_num_batched_tokens=128, + max_model_len=128, + is_encoder_decoder=model_config.is_encoder_decoder, + ) + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + cache_dtype="auto", + enable_prefix_caching=False, + ) + vllm_config = VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + ) + attention_spec = FullAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ) + mamba_spec = MambaSpec( + block_size=block_size, + shapes=((4,),), + dtypes=(torch.float32,), + mamba_cache_mode="none", + ) + kv_cache_config = KVCacheConfig( + num_blocks=num_attention_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(["attn"], attention_spec), + KVCacheGroupSpec(["mamba"], mamba_spec), + ], + pool_configs=( + KVCachePoolConfig( + pool_id=0, + memory_model=MemoryModel.TOKEN_PROPORTIONAL, + group_ids=(0,), + num_blocks=num_attention_blocks, + accounting_page_size_bytes=attention_spec.page_size_bytes, + physical_page_size_bytes=attention_spec.physical_page_size_bytes, + ), + KVCachePoolConfig( + pool_id=1, + memory_model=MemoryModel.REQUEST_CONSTANT, + group_ids=(1,), + num_blocks=2, + accounting_page_size_bytes=mamba_spec.page_size_bytes, + physical_page_size_bytes=mamba_spec.physical_page_size_bytes, + ), + ), + group_to_pool_id=(0, 1), + ) + cache_config.num_gpu_blocks = num_attention_blocks + routed_experts_reader = Mock() + monkeypatch.setattr( + "vllm.v1.core.sched.scheduler.RoutedExpertsReader.create", + lambda: routed_experts_reader, + ) + + scheduler = Scheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + block_size=block_size, + structured_output_manager=StructuredOutputManager(vllm_config), + ) + + expected_max_num_kv_tokens = num_attention_blocks * block_size + assert scheduler.routed_experts_attn_gid == 0 + assert scheduler.max_num_kv_tokens == expected_max_num_kv_tokens + routed_experts_reader.attach_buffer.assert_called_once_with( + max_num_kv_tokens=expected_max_num_kv_tokens, + vllm_config=vllm_config, + ) + + @pytest.mark.parametrize( "enable_prefix_caching, prompt_logprobs", [ diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index f59830dcd741..cafaf849f24b 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -6,7 +6,7 @@ import pytest import torch -from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.block_pool import BlockPool, CompactBlockPool from vllm.v1.core.kv_cache_utils import ( BlockHash, KVCacheBlock, @@ -14,9 +14,17 @@ ) from vllm.v1.core.single_type_kv_cache_manager import ( ChunkedLocalAttentionManager, + FullAttentionManager, + MambaManager, SlidingWindowManager, ) -from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, SlidingWindowSpec +from vllm.v1.kv_cache_interface import ( + ChunkedLocalAttentionSpec, + 
FullAttentionSpec, + MambaSpec, + SlidingWindowSpec, + TQFullAttentionSpec, +) pytestmark = pytest.mark.cpu_test @@ -44,6 +52,205 @@ def get_chunked_local_attention_manager( ) +@pytest.mark.parametrize( + ("kv_cache_spec", "manager_cls", "should_record"), + [ + ( + FullAttentionSpec( + block_size=2, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ), + FullAttentionManager, + True, + ), + ( + TQFullAttentionSpec( + block_size=2, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + tq_slot_size=1, + ), + FullAttentionManager, + True, + ), + ( + SlidingWindowSpec( + block_size=2, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=4, + ), + SlidingWindowManager, + False, + ), + ( + ChunkedLocalAttentionSpec( + block_size=2, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + attention_chunk_size=4, + ), + ChunkedLocalAttentionManager, + False, + ), + ( + MambaSpec( + block_size=2, + shapes=((1,),), + dtypes=(torch.float32,), + ), + MambaManager, + False, + ), + ], +) +def test_legacy_new_block_ids_for_zeroing_behavior( + kv_cache_spec, manager_cls, should_record +): + block_pool = BlockPool( + num_gpu_blocks=10, + enable_caching=False, + hash_block_size=kv_cache_spec.block_size, + ) + manager = manager_cls( + kv_cache_spec, + block_pool=block_pool, + enable_caching=False, + kv_cache_group_id=0, + ) + + new_blocks = manager.allocate_new_blocks( + request_id="request", + num_tokens=4, + num_tokens_main_model=4, + ) + + if should_record: + assert manager.take_new_block_ids() == [b.block_id for b in new_blocks] + else: + assert manager.take_new_block_ids() == [] + assert manager.take_new_block_ids() == [] + + +def test_mamba_manager_accepts_allocation_only_pool_when_caching_disabled(): + kv_cache_spec = MambaSpec( + block_size=2, + shapes=((1,),), + dtypes=(torch.float32,), + ) + block_pool = CompactBlockPool(num_allocatable=2) + manager = MambaManager( + kv_cache_spec, + block_pool=block_pool, + enable_caching=False, + kv_cache_group_id=0, + ) + + new_blocks = manager.allocate_new_blocks( + request_id="request", + num_tokens=4, + num_tokens_main_model=4, + ) + + assert [block.block_id for block in new_blocks] == [2] + assert block_pool.get_num_free_blocks() == 1 + manager.free("request") + assert block_pool.get_num_free_blocks() == 2 + + +def test_mamba_manager_request_constant_none_allocates_once(): + kv_cache_spec = MambaSpec( + block_size=2, + shapes=((1,),), + dtypes=(torch.float32,), + mamba_cache_mode="none", + num_speculative_blocks=1, + ) + block_pool = CompactBlockPool(num_allocatable=4) + manager = MambaManager( + kv_cache_spec, + block_pool=block_pool, + enable_caching=False, + kv_cache_group_id=0, + ) + + assert ( + manager.get_num_blocks_to_allocate( + "request", + num_tokens=100, + new_computed_blocks=[], + total_computed_tokens=0, + num_tokens_main_model=100, + ) + == kv_cache_spec.blocks_per_request + ) + new_blocks = manager.allocate_new_blocks( + request_id="request", + num_tokens=100, + num_tokens_main_model=100, + ) + + assert len(new_blocks) == kv_cache_spec.blocks_per_request + assert all(block.block_id != 0 for block in new_blocks) + assert block_pool.get_num_free_blocks() == 2 + assert ( + manager.get_num_blocks_to_allocate( + "request", + num_tokens=200, + new_computed_blocks=[], + total_computed_tokens=100, + num_tokens_main_model=200, + ) + == 0 + ) + assert ( + manager.allocate_new_blocks( + request_id="request", + num_tokens=200, + num_tokens_main_model=200, + ) + == [] + ) + + manager.remove_skipped_blocks("request", 
num_computed_tokens=100) + assert block_pool.get_num_free_blocks() == 2 + manager.free("request") + assert block_pool.get_num_free_blocks() == 4 + + +def test_mamba_manager_request_constant_align_free_filters_null_blocks(): + kv_cache_spec = MambaSpec( + block_size=2, + shapes=((1,),), + dtypes=(torch.float32,), + mamba_cache_mode="align", + ) + block_pool = CompactBlockPool(num_allocatable=4) + manager = MambaManager( + kv_cache_spec, + block_pool=block_pool, + enable_caching=False, + kv_cache_group_id=0, + ) + + new_blocks = manager.allocate_new_blocks( + request_id="request", + num_tokens=6, + num_tokens_main_model=6, + ) + + assert len(new_blocks) == 3 + assert sum(not block.is_null for block in manager.req_to_blocks["request"]) == 1 + assert block_pool.get_num_free_blocks() == 3 + manager.free("request") + assert block_pool.get_num_free_blocks() == 4 + + def test_chunked_local_attention_possible_cached_prefix(): block_size = 2 chunked_local_attention_spec = ChunkedLocalAttentionSpec( diff --git a/tests/v1/simple_kv_offload/test_scheduler.py b/tests/v1/simple_kv_offload/test_scheduler.py index 132f52fe3b36..8d37aab4428c 100644 --- a/tests/v1/simple_kv_offload/test_scheduler.py +++ b/tests/v1/simple_kv_offload/test_scheduler.py @@ -6,6 +6,7 @@ from dataclasses import dataclass +import pytest import torch from vllm import SamplingParams @@ -34,7 +35,10 @@ FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, + KVCachePoolConfig, KVCacheTensor, + MambaSpec, + MemoryModel, ) from vllm.v1.outputs import KVConnectorOutput from vllm.v1.request import Request @@ -94,6 +98,59 @@ def _make_kv_cache_config( ) +def _make_request_constant_kv_cache_config(num_blocks: int = 8) -> KVCacheConfig: + attention_spec = FullAttentionSpec( + block_size=BLOCK_SIZE, + num_kv_heads=NUM_KV_HEADS, + head_size=HEAD_SIZE, + dtype=DTYPE, + ) + mamba_spec = MambaSpec( + block_size=BLOCK_SIZE, + shapes=((1,),), + dtypes=(DTYPE,), + mamba_cache_mode="none", + page_size_padded=attention_spec.page_size_bytes, + ) + mamba_num_blocks = 3 + return KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[ + KVCacheTensor( + size=attention_spec.page_size_bytes * num_blocks, + shared_by=["attn"], + ), + KVCacheTensor( + size=mamba_spec.physical_page_size_bytes * mamba_num_blocks, + shared_by=["mamba"], + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["attn"], attention_spec), + KVCacheGroupSpec(["mamba"], mamba_spec), + ], + pool_configs=( + KVCachePoolConfig( + pool_id=0, + memory_model=MemoryModel.TOKEN_PROPORTIONAL, + group_ids=(0,), + num_blocks=num_blocks, + accounting_page_size_bytes=attention_spec.accounting_page_size_bytes, + physical_page_size_bytes=attention_spec.physical_page_size_bytes, + ), + KVCachePoolConfig( + pool_id=1, + memory_model=MemoryModel.REQUEST_CONSTANT, + group_ids=(1,), + num_blocks=mamba_num_blocks, + accounting_page_size_bytes=mamba_spec.accounting_page_size_bytes, + physical_page_size_bytes=mamba_spec.physical_page_size_bytes, + ), + ), + group_to_pool_id=(0, 1), + ) + + def _make_vllm_config(block_size: int = BLOCK_SIZE) -> VllmConfig: """Minimal VllmConfig for scheduler tests (no GPU).""" model_config = ModelConfig( @@ -173,6 +230,19 @@ def make_scheduler( ) +def test_cpu_offload_rejects_request_constant_kv_cache(): + kv_cache_config = _make_request_constant_kv_cache_config() + + with pytest.raises( + NotImplementedError, + match="CPU KV cache offload with REQUEST_CONSTANT specs", + ): + SimpleCPUOffloadScheduler._derive_cpu_config( + kv_cache_config, + cpu_capacity_bytes=_BYTES_PER_BLOCK, + 
) + + _req_counter = 0 diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 0de443858c98..5632610d7fd0 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass from types import SimpleNamespace import numpy as np @@ -22,7 +23,9 @@ ) from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.models import ModelRegistry from vllm.platforms import current_platform +from vllm.platforms.interface import Platform from vllm.sampling_params import SamplingParams from vllm.utils.mem_constants import GiB_bytes from vllm.utils.system_utils import update_environment_variables @@ -37,10 +40,15 @@ KVCacheConfig, KVCacheGroupSpec, KVCacheTensor, + MambaSpec, + MemoryModel, ) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.worker.kv_connector_model_runner_mixin import ( + KVConnectorModelRunnerMixin, +) from vllm.v1.worker.utils import AttentionGroup, select_common_block_size BLOCK_SIZE = 16 @@ -48,6 +56,134 @@ DEVICE_TYPE = current_platform.device_type +@dataclass(frozen=True) +class _RequestConstantMambaSpec(MambaSpec): + """Test-only Mamba-like request-constant spec for worker reshape paths.""" + + @property + def memory_model(self) -> MemoryModel: + return MemoryModel.REQUEST_CONSTANT + + +@dataclass(frozen=True, kw_only=True) +class _RequestConstantFullAttentionSpec(FullAttentionSpec): + """Test-only invalid attention spec used to verify fail-closed guards.""" + + @property + def memory_model(self) -> MemoryModel: + return MemoryModel.REQUEST_CONSTANT + + +class _TestAttentionBackend: + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, int, int, int, int]: + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_kv_cache_stride_order(): + return tuple(range(5)) + + +class _HybridBlockSizeTestBackend: + @staticmethod + def get_supported_kernel_block_sizes() -> list[int]: + return [16, 32, 64] + + @staticmethod + def get_name() -> str: + return "HYBRID_BLOCK_SIZE_TEST" + + +class _HybridBlockSizeTestModel: + @staticmethod + def get_mamba_state_shape_from_config(vllm_config): + return ((130,),) + + @staticmethod + def get_mamba_state_dtype_from_config(vllm_config): + return (torch.float32,) + + +class _HybridBlockSizeTestModelConfig: + dtype = torch.float16 + use_mla = False + architecture = "HybridBlockSizeTestModel" + + def get_num_kv_heads(self, parallel_config): + return 1 + + def get_head_size(self): + return 1 + + def get_mamba_chunk_size(self): + return 16 + + +def _make_hybrid_block_size_test_config( + *, + mamba_cache_mode: str, + block_size: int = 16, + mamba_page_size_padded: int | None = None, +): + cache_config = CacheConfig( + block_size=block_size, + cache_dtype="auto", + mamba_cache_mode=mamba_cache_mode, + ) + cache_config.mamba_page_size_padded = mamba_page_size_padded + return SimpleNamespace( + cache_config=cache_config, + model_config=_HybridBlockSizeTestModelConfig(), + parallel_config=ParallelConfig(), + ) + + +def 
_patch_hybrid_block_size_test_model(monkeypatch): + def resolve_model_cls(*args, **kwargs): + return _HybridBlockSizeTestModel, None + + monkeypatch.setattr(ModelRegistry, "resolve_model_cls", resolve_model_cls) + + +def _reshape_kv_cache_tensor_for_test( + kv_cache_spec, + raw_tensor: torch.Tensor, + layer_name: str = "layer.0", +): + group = AttentionGroup( + _TestAttentionBackend, + [layer_name], + kv_cache_spec, + 0, + ) + runner_stub = SimpleNamespace( + runner_only_attn_layers=set(), + cache_config=SimpleNamespace(cache_dtype="auto"), + _kv_cache_spec_attn_group_iterator=lambda: iter([group]), + ) + kv_cache_config = KVCacheConfig( + num_blocks=1, + kv_cache_tensors=[ + KVCacheTensor(size=raw_tensor.numel(), shared_by=[layer_name]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(layer_names=[layer_name], kv_cache_spec=kv_cache_spec) + ], + ) + return GPUModelRunner._reshape_kv_cache_tensors( + runner_stub, + kv_cache_config, + {layer_name: raw_tensor}, + [kv_cache_spec.block_size], + ) + + def initialize_kv_cache(runner: GPUModelRunner): """ Only perform necessary steps in GPUModelRunner.initialize_kv_cache() @@ -100,6 +236,7 @@ def get_vllm_config(): block_size=BLOCK_SIZE, gpu_memory_utilization=0.9, cache_dtype="auto", + mamba_cache_mode="all", ) parallel_config = ParallelConfig() vllm_config = VllmConfig( @@ -977,6 +1114,179 @@ def test_update_hybrid_attention_mamba_layout_with_num_block_2_rewrites_stride() which was ambiguous before get_kv_cache_block_dim was used""" +@pytest.mark.parametrize( + ( + "mamba_cache_mode", + "initial_mamba_block_size", + "initial_page_size_padded", + "expected_block_size", + "expected_mamba_block_size", + "expected_page_size_padded", + ), + [ + ("none", None, 999, 16, None, None), + ("align", 2048, 999, 16, 16, None), + ("all", None, None, 144, 144, 576), + ], + ids=["request_constant_none", "request_constant_align", "token_proportional_all"], +) +def test_request_constant_mamba_and_token_proportional_mamba_all_platform_padding( + monkeypatch, + mamba_cache_mode, + initial_mamba_block_size, + initial_page_size_padded, + expected_block_size, + expected_mamba_block_size, + expected_page_size_padded, +): + _patch_hybrid_block_size_test_model(monkeypatch) + vllm_config = _make_hybrid_block_size_test_config( + mamba_cache_mode=mamba_cache_mode, + mamba_page_size_padded=initial_page_size_padded, + ) + vllm_config.cache_config.mamba_block_size = initial_mamba_block_size + + Platform._align_hybrid_block_size(vllm_config, _HybridBlockSizeTestBackend) + + assert vllm_config.cache_config.block_size == expected_block_size + assert vllm_config.cache_config.mamba_block_size == expected_mamba_block_size + assert vllm_config.cache_config.mamba_page_size_padded == expected_page_size_padded + + +def test_reshape_request_constant_mamba_uses_physical_page(): + spec = _RequestConstantMambaSpec( + block_size=1, + shapes=((2,),), + dtypes=(torch.float32,), + page_size_padded=16, + ) + num_blocks = 3 + raw_tensor = torch.empty( + spec.physical_page_size_bytes * num_blocks, + dtype=torch.int8, + ) + + kv_caches = _reshape_kv_cache_tensor_for_test(spec, raw_tensor, "state") + + assert kv_caches["state"][0].shape == (num_blocks, 2) + + +def test_reshape_request_constant_mamba_stride_matches_physical(): + dtype = torch.float32 + spec = _RequestConstantMambaSpec( + block_size=1, + shapes=((2,),), + dtypes=(dtype,), + page_size_padded=16, + ) + raw_tensor = torch.empty( + spec.physical_page_size_bytes * 3, + dtype=torch.int8, + ) + + kv_caches = 
_reshape_kv_cache_tensor_for_test(spec, raw_tensor, "state") + state_tensor = kv_caches["state"][0] + dtype_size = torch.empty((), dtype=dtype).element_size() + + assert state_tensor.stride()[0] == spec.physical_page_size_bytes // dtype_size + assert state_tensor.stride()[0] != spec.page_size_bytes // dtype_size + + +def test_reshape_token_proportional_mamba_uses_padded_page(): + dtype = torch.float32 + spec = MambaSpec( + block_size=1, + shapes=((2,),), + dtypes=(dtype,), + page_size_padded=16, + mamba_cache_mode="all", + ) + num_blocks = 3 + raw_tensor = torch.empty( + spec.page_size_bytes * num_blocks, + dtype=torch.int8, + ) + + kv_caches = _reshape_kv_cache_tensor_for_test(spec, raw_tensor, "state") + state_tensor = kv_caches["state"][0] + dtype_size = torch.empty((), dtype=dtype).element_size() + + assert state_tensor.shape == (num_blocks, 2) + assert state_tensor.stride()[0] == spec.page_size_bytes // dtype_size + + +def test_reshape_token_proportional_attention_unchanged(): + spec = FullAttentionSpec( + block_size=2, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + page_size_padded=32, + ) + num_blocks = 2 + raw_tensor = torch.empty( + spec.page_size_bytes * num_blocks, + dtype=torch.int8, + ) + + kv_caches = _reshape_kv_cache_tensor_for_test(spec, raw_tensor, "attn") + kv_cache = kv_caches["attn"] + + assert kv_cache.shape[1] == num_blocks + assert ( + kv_cache.numel() + == raw_tensor.numel() // torch.empty((), dtype=spec.dtype).element_size() + ) + + +def test_reshape_attention_request_constant_rejected(): + spec = _RequestConstantFullAttentionSpec( + block_size=2, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ) + raw_tensor = torch.empty(spec.page_size_bytes, dtype=torch.int8) + + with pytest.raises(NotImplementedError, match="REQUEST_CONSTANT AttentionSpec"): + _reshape_kv_cache_tensor_for_test(spec, raw_tensor, "attn") + + +def test_reshape_connector_request_constant_rejected(): + spec = _RequestConstantFullAttentionSpec( + block_size=2, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ) + kv_cache_config = KVCacheConfig( + num_blocks=1, + kv_cache_tensors=[ + KVCacheTensor(size=spec.page_size_bytes, shared_by=["attn"]), + ], + kv_cache_groups=[KVCacheGroupSpec(layer_names=["attn"], kv_cache_spec=spec)], + ) + attn_groups = [ + [ + AttentionGroup( + _TestAttentionBackend, + ["attn"], + spec, + 0, + ) + ] + ] + + with pytest.raises(NotImplementedError, match="Cross-layer KV connector"): + KVConnectorModelRunnerMixin.allocate_uniform_kv_caches( + kv_cache_config, + attn_groups, + "auto", + torch.device("cpu"), + [spec.block_size], + ) + + def test_hybrid_block_table_initialization(): """Test hybrid block table with different kernel and kvcache_manager block sizes.""" diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 0f02a92681c1..343f25c6c197 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -1417,6 +1417,26 @@ def resolve_cudagraph_mode_and_sizes( tensor_parallel_size, ) + if ( + kv_cache_config is not None + and cudagraph_mode.has_full_cudagraphs() + and not is_profiling + and kv_cache_config.has_mamba_layers + ): + from vllm.v1.kv_cache_interface import MemoryModel + + if any( + pool.memory_model == MemoryModel.REQUEST_CONSTANT + for pool in kv_cache_config.pool_configs + ): + raise ValueError( + "Full CUDA graph capture with REQUEST_CONSTANT KV cache " + "(Mamba in 'none' or 'align' mode) is not yet supported. 
" + "Either disable cudagraph capture (e.g., enforce_eager=True) " + "or set mamba_cache_mode='all' to use the legacy " + "shared-pool path." + ) + # For Mamba models with FULL decode cudagraphs, each decode # sequence needs one Mamba cache block. The decode cudagraph # dispatcher already caps batch sizes at max_num_seqs, so we just diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 2753326755fb..5f771d6d07d3 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -575,6 +575,16 @@ def _align_hybrid_block_size( else None ) + if cache_config.mamba_cache_mode != "all": + # REQUEST_CONSTANT Mamba modes use a separate compact pool sized by + # the physical Mamba state page. They no longer require attention + # pages to be inflated to match Mamba pages, nor Mamba pages to be + # padded to the attention page size. + if cache_config.mamba_cache_mode == "align": + cache_config.mamba_block_size = cache_config.block_size + cache_config.mamba_page_size_padded = None + return + # Get kernel block alignment from the backend's supported sizes with set_current_vllm_config(vllm_config): kernel_block_alignment_size = max( diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index bd9006f3f8fc..613c4dce1f22 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -294,14 +294,23 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: new_block_size * attn_page_size_1_token ) cache_config.block_size = new_block_size - logger.info( - "[XPU]Setting attention block size to %d tokens to ensure multiple of %d, " - "set mamba_page_size_padded to %d bytes accordingly, before was %d bytes.", - new_block_size, - kernel_block_size, - cache_config.mamba_page_size_padded, - original_mamba_page_size_padded, - ) + if original_mamba_page_size_padded is None: + logger.info( + "[XPU]Setting attention block size to %d tokens to ensure " + "multiple of %d; mamba_page_size_padded remains unset.", + new_block_size, + kernel_block_size, + ) + else: + logger.info( + "[XPU]Setting attention block size to %d tokens to ensure " + "multiple of %d, set mamba_page_size_padded to %d bytes " + "accordingly, before was %d bytes.", + new_block_size, + kernel_block_size, + cache_config.mamba_page_size_padded, + original_mamba_page_size_padded, + ) @classmethod def support_hybrid_kv_cache(cls) -> bool: diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 9097079ef33a..638944c0a14d 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Sequence -from typing import Any +from typing import Any, Protocol from vllm.distributed.kv_events import ( MEDIUM_GPU, @@ -31,6 +31,115 @@ logger = init_logger(__name__) +class BlockPoolProtocol(Protocol): + """Basic allocation contract shared by all KV cache block pools. + + Implementations must reserve ``block_id=0`` as the null sentinel. Allocation + methods must never return the null block, and backing tensors must be sized + so that block index 0 is always valid. + """ + + num_gpu_blocks: int + null_block: KVCacheBlock + + def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: ... + + def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: ... + + def get_num_free_blocks(self) -> int: ... + + def get_usage(self) -> float: ... 
+ + +class CacheableBlockPoolProtocol(BlockPoolProtocol, Protocol): + """Block-pool contract for pools that also support prefix caching.""" + + def get_cached_block( + self, block_hash: BlockHash, kv_cache_group_ids: list[int] + ) -> list[KVCacheBlock] | None: ... + + def cache_full_blocks( + self, + request: Request, + blocks: list[KVCacheBlock], + num_cached_blocks: int, + num_full_blocks: int, + block_size: int, + kv_cache_group_id: int, + ) -> None: ... + + def touch(self, blocks: Sequence[KVCacheBlock]) -> None: ... + + def evict_blocks(self, block_ids: set[int]) -> None: ... + + def reset_prefix_cache(self) -> bool: ... + + def take_events(self) -> list[KVCacheEvent]: ... + + +class CompactBlockPool: + """Compact allocation-only pool for request-constant KV-cache specs. + + Invariants enforced by this pool: + (1) ``block_id=0`` is reserved as the null sentinel. + (2) Backing tensors must be sized for ``num_allocatable + 1`` blocks. + (3) Allocation never returns the null block. + (4) ``ref_cnt`` is binary: 0 when free, 1 when allocated. + (5) ``get_new_blocks(0)`` returns ``[]`` and ``free_blocks([])`` is a no-op. + (6) Freeing a null block or a non-allocated block is rejected. + + This pool deliberately does not implement prefix-caching APIs such as + ``touch``, ``get_cached_block``, or ``cache_full_blocks``. + """ + + def __init__(self, num_allocatable: int) -> None: + assert isinstance(num_allocatable, int) and num_allocatable >= 0 + self._num_allocatable = num_allocatable + self.num_gpu_blocks = num_allocatable + 1 + self.null_block = KVCacheBlock(block_id=0) + self.null_block.is_null = True + self._free: list[KVCacheBlock] = [ + KVCacheBlock(block_id=i) for i in range(1, num_allocatable + 1) + ] + + def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: + if num_blocks == 0: + return [] + if num_blocks > self.get_num_free_blocks(): + raise ValueError(f"Cannot get {num_blocks} free blocks from the pool") + + blocks = [self._free.pop() for _ in range(num_blocks)] + for block in blocks: + assert block.block_id != 0 + assert not block.is_null + assert block.ref_cnt == 0 + block.ref_cnt += 1 + assert block.ref_cnt == 1 + return blocks + + def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: + blocks = list(ordered_blocks) + if not blocks: + return + for block in blocks: + assert block.block_id != 0, "null block must never be freed" + assert not block.is_null, "null block must never be freed" + assert block.ref_cnt == 1, ( + "CompactBlockPool expects binary ref_cnt semantics" + ) + block.ref_cnt -= 1 + assert block.ref_cnt == 0 + self._free.extend(blocks) + + def get_num_free_blocks(self) -> int: + return len(self._free) + + def get_usage(self) -> float: + if self._num_allocatable == 0: + return 0 + return 1.0 - (self.get_num_free_blocks() / self._num_allocatable) + + class BlockHashToBlockMap: """ Cache of blocks that are used for prefix caching. 
It caches blocks diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 65993e804153..f0fd97f3b1e3 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -3,8 +3,14 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from math import lcm +from typing import cast -from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.block_pool import ( + BlockPool, + BlockPoolProtocol, + CacheableBlockPoolProtocol, + CompactBlockPool, +) from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector from vllm.v1.core.kv_cache_utils import ( BlockHash, @@ -20,7 +26,9 @@ from vllm.v1.kv_cache_interface import ( FullAttentionSpec, KVCacheConfig, + KVCachePoolConfig, KVCacheSpec, + MemoryModel, ) from vllm.v1.request import Request @@ -47,13 +55,18 @@ def __init__( self.max_model_len = max_model_len self.enable_caching = enable_caching - self.block_pool = BlockPool( - kv_cache_config.num_blocks, - enable_caching, - hash_block_size, + self._check_pool_config_supported() + self._block_pools = self._make_block_pools( enable_kv_cache_events, + hash_block_size, metrics_collector, ) + # Public legacy alias. Existing consumers assume a single shared pool. + self.block_pool = self._block_pools[0] + self._group_to_pool = tuple( + self._block_pools[pool_id] + for pool_id in self.kv_cache_config.group_to_pool_id + ) # KV cache group indices that get the EAGLE last-block drop. self.eagle_group_ids: set[int] = { @@ -68,7 +81,7 @@ def __init__( kv_cache_spec=kv_cache_group.kv_cache_spec, max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_model_len, - block_pool=self.block_pool, + block_pool=self._group_to_pool[i], enable_caching=enable_caching, kv_cache_group_id=i, dcp_world_size=dcp_world_size, @@ -77,6 +90,68 @@ def __init__( for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups) ) + def _check_pool_config_supported(self) -> None: + """Reject unsupported multi-pool prefix-cache configs.""" + pool_ids = set(self.kv_cache_config.group_to_pool_id) + num_distinct_pools = max( + len(self.kv_cache_config.pool_configs), + len(pool_ids), + ) + if num_distinct_pools > 1 and self.enable_caching: + raise NotImplementedError( + "KVCacheCoordinator currently supports multi-pool configs only " + "when prefix caching is disabled. Got " + f"{num_distinct_pools} distinct pools in kv_cache_config. " + "Pool-aware prefix-cache dispatch is not implemented yet." 
+ ) + + def _make_block_pools( + self, + enable_kv_cache_events: bool, + hash_block_size: int, + metrics_collector: KVCacheMetricsCollector | None, + ) -> tuple[BlockPoolProtocol, ...]: + if not self.kv_cache_config.pool_configs: + return ( + BlockPool( + self.kv_cache_config.num_blocks, + self.enable_caching, + hash_block_size, + enable_kv_cache_events, + metrics_collector, + ), + ) + + return tuple( + self._make_block_pool( + pool_config, + enable_kv_cache_events, + hash_block_size, + metrics_collector, + ) + for pool_config in self.kv_cache_config.pool_configs + ) + + def _make_block_pool( + self, + pool_config: KVCachePoolConfig, + enable_kv_cache_events: bool, + hash_block_size: int, + metrics_collector: KVCacheMetricsCollector | None, + ) -> BlockPoolProtocol: + if pool_config.memory_model == MemoryModel.TOKEN_PROPORTIONAL: + return BlockPool( + pool_config.num_blocks, + self.enable_caching, + hash_block_size, + enable_kv_cache_events, + metrics_collector, + ) + if pool_config.memory_model == MemoryModel.REQUEST_CONSTANT: + assert not self.enable_caching + return CompactBlockPool(num_allocatable=pool_config.num_blocks - 1) + raise AssertionError(f"Unsupported KV cache memory model: {pool_config}") + def get_num_blocks_to_allocate( self, request_id: str, @@ -110,12 +185,35 @@ def get_num_blocks_to_allocate( Returns: The number of blocks to allocate. """ - num_blocks_to_allocate = 0 + return sum( + self.get_num_blocks_to_allocate_by_pool( + request_id, + num_tokens, + new_computed_blocks, + num_encoder_tokens, + total_computed_tokens, + num_tokens_main_model, + apply_admission_cap=apply_admission_cap, + ) + ) + + def get_num_blocks_to_allocate_by_pool( + self, + request_id: str, + num_tokens: int, + new_computed_blocks: tuple[Sequence[KVCacheBlock], ...], + num_encoder_tokens: int, + total_computed_tokens: int, + num_tokens_main_model: int, + apply_admission_cap: bool = False, + ) -> tuple[int, ...]: + """Get the number of blocks needed from each KV cache pool.""" + num_blocks_to_allocate = [0] * len(self.kv_cache_config.pool_configs) for i, manager in enumerate(self.single_type_managers): if isinstance(manager, CrossAttentionManager): # For cross-attention, we issue a single static allocation # of blocks based on the number of encoder input tokens. 
- num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + required_blocks = manager.get_num_blocks_to_allocate( request_id, num_encoder_tokens, [], @@ -124,7 +222,7 @@ def get_num_blocks_to_allocate( apply_admission_cap=apply_admission_cap, ) else: - num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + required_blocks = manager.get_num_blocks_to_allocate( request_id, num_tokens, new_computed_blocks[i], @@ -132,7 +230,16 @@ def get_num_blocks_to_allocate( num_tokens_main_model, apply_admission_cap=apply_admission_cap, ) - return num_blocks_to_allocate + pool_id = self.kv_cache_config.group_to_pool_id[i] + num_blocks_to_allocate[pool_id] += required_blocks + return tuple(num_blocks_to_allocate) + + def get_num_free_blocks_by_pool(self) -> tuple[int, ...]: + """Get the number of currently free blocks in each KV cache pool.""" + return tuple( + self._block_pools[pool_id].get_num_free_blocks() + for pool_id in range(len(self.kv_cache_config.pool_configs)) + ) def allocate_new_computed_blocks( self, @@ -375,11 +482,12 @@ def find_longest_cache_hit( block_hashes: list[BlockHash], max_cache_hit_length: int, ) -> tuple[tuple[list[KVCacheBlock], ...], int]: + block_pool = cast(CacheableBlockPoolProtocol, self.block_pool) hit_blocks = self.single_type_managers[0].find_longest_cache_hit( block_hashes=block_hashes, max_length=max_cache_hit_length, kv_cache_group_ids=[0], - block_pool=self.block_pool, + block_pool=block_pool, kv_cache_spec=self.kv_cache_spec, use_eagle=0 in self.eagle_group_ids, alignment_tokens=self.block_size, @@ -553,11 +661,12 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: _max_length = min( curr_hit_length + spec.block_size, max_cache_hit_length ) + block_pool = cast(CacheableBlockPoolProtocol, self.block_pool) hit_blocks = manager_cls.find_longest_cache_hit( block_hashes=_get_block_hashes(spec), max_length=_max_length, kv_cache_group_ids=group_ids, - block_pool=self.block_pool, + block_pool=block_pool, kv_cache_spec=spec, use_eagle=use_eagle, alignment_tokens=self.lcm_block_size, diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 431776870cf4..3bf3dcbe6fc1 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -4,10 +4,11 @@ import itertools from collections.abc import Sequence from dataclasses import dataclass -from typing import Literal, overload +from typing import Literal, cast, overload from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger +from vllm.v1.core.block_pool import CacheableBlockPoolProtocol from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector from vllm.v1.core.kv_cache_utils import KVCacheBlock @@ -168,6 +169,10 @@ def usage(self) -> float: """ return self.block_pool.get_usage() + def _cacheable_block_pool(self) -> CacheableBlockPoolProtocol: + """Return the legacy shared pool as a prefix-cache-capable pool.""" + return cast(CacheableBlockPoolProtocol, self.block_pool) + def make_prefix_cache_stats(self) -> PrefixCacheStats | None: """Get (and reset) the prefix cache stats. @@ -180,6 +185,38 @@ def make_prefix_cache_stats(self) -> PrefixCacheStats | None: self.prefix_cache_stats = PrefixCacheStats() return stats + def _has_enough_free_blocks_by_pool( + self, num_blocks_to_allocate_by_pool: tuple[int, ...] 
+ ) -> bool: + num_free_blocks_by_pool = self.coordinator.get_num_free_blocks_by_pool() + assert len(num_blocks_to_allocate_by_pool) == len(num_free_blocks_by_pool) + return all( + num_blocks_to_allocate <= num_free_blocks + for num_blocks_to_allocate, num_free_blocks in zip( + num_blocks_to_allocate_by_pool, num_free_blocks_by_pool + ) + ) + + def _get_num_blocks_to_allocate_by_pool( + self, + request: Request, + num_tokens: int, + new_computed_block_list: tuple[Sequence[KVCacheBlock], ...], + num_encoder_tokens: int, + total_computed_tokens: int, + num_tokens_main_model: int, + apply_admission_cap: bool = False, + ) -> tuple[int, ...]: + return self.coordinator.get_num_blocks_to_allocate_by_pool( + request_id=request.request_id, + num_tokens=num_tokens, + new_computed_blocks=new_computed_block_list, + num_encoder_tokens=num_encoder_tokens, + total_computed_tokens=total_computed_tokens, + num_tokens_main_model=num_tokens_main_model, + apply_admission_cap=apply_admission_cap, + ) + def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -336,16 +373,16 @@ def allocate_slots( # First check and fail if the full request sequence won't fit. full_num_tokens = min(request.num_tokens, self.max_model_len) - num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( - request_id=request.request_id, + num_blocks_to_allocate_by_pool = self._get_num_blocks_to_allocate_by_pool( + request, num_tokens=full_num_tokens, - new_computed_blocks=new_computed_block_list, + new_computed_block_list=new_computed_block_list, num_encoder_tokens=num_encoder_tokens, total_computed_tokens=total_computed_tokens, num_tokens_main_model=full_num_tokens, apply_admission_cap=True, ) - if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): + if not self._has_enough_free_blocks_by_pool(num_blocks_to_allocate_by_pool): return None num_tokens_main_model = total_computed_tokens + num_new_tokens @@ -363,17 +400,17 @@ def allocate_slots( request.request_id, total_computed_tokens ) - num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( - request_id=request.request_id, + num_blocks_to_allocate_by_pool = self._get_num_blocks_to_allocate_by_pool( + request, num_tokens=num_tokens_need_slot, - new_computed_blocks=new_computed_block_list, + new_computed_block_list=new_computed_block_list, num_encoder_tokens=num_encoder_tokens, total_computed_tokens=num_local_computed_tokens + num_external_computed_tokens, num_tokens_main_model=num_tokens_main_model, ) - if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): + if not self._has_enough_free_blocks_by_pool(num_blocks_to_allocate_by_pool): # Cannot allocate new blocks return None @@ -444,7 +481,7 @@ def evict_blocks(self, block_ids: set[int]) -> None: Args: block_ids: Set of block IDs to evict from cache. """ - self.block_pool.evict_blocks(block_ids) + self._cacheable_block_pool().evict_blocks(block_ids) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -455,7 +492,7 @@ def reset_prefix_cache(self) -> bool: bool: True if the prefix cache is successfully reset, False otherwise. """ - if not self.block_pool.reset_prefix_cache(): + if not self._cacheable_block_pool().reset_prefix_cache(): return False if self.log_stats: assert self.prefix_cache_stats is not None @@ -502,7 +539,7 @@ def take_events(self) -> list[KVCacheEvent]: Returns: A list of KV cache events. 
""" - return self.block_pool.take_events() + return self._cacheable_block_pool().take_events() def get_blocks(self, request_id: str) -> KVCacheBlocks: """Get the blocks of a request.""" diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index b57e10b67faa..abb3c571792a 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -23,9 +23,11 @@ FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, + KVCachePoolConfig, KVCacheSpec, KVCacheTensor, MambaSpec, + MemoryModel, MLAAttentionSpec, SlidingWindowMLASpec, SlidingWindowSpec, @@ -875,21 +877,96 @@ def get_max_concurrency_for_kv_cache_config( """ Get the maximum concurrency for the given KV cache configuration. """ - num_layer_per_group = max( - len(group.layer_names) for group in kv_cache_config.kv_cache_groups - ) + if _has_request_constant_pools(kv_cache_config): + group_ids = [ + group_id + for group_id, group in enumerate(kv_cache_config.kv_cache_groups) + if group.kv_cache_spec.memory_model == MemoryModel.TOKEN_PROPORTIONAL + ] + if not group_ids: + return _get_request_constant_max_concurrency(kv_cache_config) + groups = [kv_cache_config.kv_cache_groups[group_id] for group_id in group_ids] + num_blocks = _get_num_blocks_for_group_ids(kv_cache_config, group_ids) + else: + groups = kv_cache_config.kv_cache_groups + num_blocks = kv_cache_config.num_blocks + + num_layer_per_group = max(len(group.layer_names) for group in groups) max_memory_usage_per_request = num_layer_per_group * max_memory_usage_bytes( - vllm_config, (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups) - ) - memory_per_block = ( - kv_cache_config.kv_cache_groups[0].kv_cache_spec.page_size_bytes - * num_layer_per_group + vllm_config, (group.kv_cache_spec for group in groups) ) + memory_per_block = groups[0].kv_cache_spec.page_size_bytes * num_layer_per_group num_block_per_request = cdiv(max_memory_usage_per_request, memory_per_block) - max_concurrency = kv_cache_config.num_blocks / num_block_per_request + max_concurrency = num_blocks / num_block_per_request return max_concurrency +def _has_request_constant_pools(kv_cache_config: KVCacheConfig) -> bool: + return any( + pool.memory_model == MemoryModel.REQUEST_CONSTANT + for pool in kv_cache_config.pool_configs + ) + + +def _get_request_constant_max_concurrency( + kv_cache_config: KVCacheConfig, +) -> float: + pool_by_id = {pool.pool_id: pool for pool in kv_cache_config.pool_configs} + per_group_capacity = [] + for group_id, group in enumerate(kv_cache_config.kv_cache_groups): + if group.kv_cache_spec.memory_model != MemoryModel.REQUEST_CONSTANT: + continue + pool_id = kv_cache_config.group_to_pool_id[group_id] + pool = pool_by_id[pool_id] + per_group_capacity.append( + (pool.num_blocks - 1) / group.kv_cache_spec.blocks_per_request + ) + if not per_group_capacity: + return 0 + return min(per_group_capacity) + + +def _get_token_proportional_group_ids( + kv_cache_config: KVCacheConfig, +) -> list[int]: + group_ids = [ + group_id + for group_id, group in enumerate(kv_cache_config.kv_cache_groups) + if group.kv_cache_spec.memory_model == MemoryModel.TOKEN_PROPORTIONAL + ] + if not group_ids: + raise NotImplementedError( + "KV cache capacity reporting requires at least one " + "TOKEN_PROPORTIONAL group." 
+ ) + return group_ids + + +def _get_num_blocks_for_group_ids( + kv_cache_config: KVCacheConfig, + group_ids: list[int], +) -> int: + pool_ids = {kv_cache_config.group_to_pool_id[group_id] for group_id in group_ids} + return sum( + pool.num_blocks + for pool in kv_cache_config.pool_configs + if pool.pool_id in pool_ids + ) + + +def get_token_proportional_kv_cache_capacity_tokens( + kv_cache_config: KVCacheConfig, +) -> int: + """Return token capacity represented by TOKEN_PROPORTIONAL pools only.""" + token_group_ids = _get_token_proportional_group_ids(kv_cache_config) + min_block_size = min( + kv_cache_config.kv_cache_groups[group_id].kv_cache_spec.block_size + for group_id in token_group_ids + ) + num_blocks = _get_num_blocks_for_group_ids(kv_cache_config, token_group_ids) + return num_blocks // len(token_group_ids) * min_block_size + + def may_override_num_blocks(vllm_config: VllmConfig, num_blocks: int) -> int: """ Override the number of kv cache blocks if `num_gpu_blocks_override` is set. @@ -1049,6 +1126,13 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo return not kv_cache_spec +def _has_request_constant_specs(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: + return any( + spec.memory_model == MemoryModel.REQUEST_CONSTANT + for spec in kv_cache_spec.values() + ) + + def _get_kv_cache_groups_uniform_page_size( kv_cache_spec: dict[str, KVCacheSpec], ) -> list[KVCacheGroupSpec]: @@ -1228,6 +1312,140 @@ def _get_kv_cache_config_deepseek_v4( return num_blocks, kv_cache_tensors +def _has_request_constant_groups( + kv_cache_groups: list[KVCacheGroupSpec], +) -> bool: + return any( + group.kv_cache_spec.memory_model == MemoryModel.REQUEST_CONSTANT + for group in kv_cache_groups + ) + + +def _get_request_constant_num_blocks( + vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec +) -> int: + blocks_per_request = kv_cache_spec.blocks_per_request + assert blocks_per_request > 0 + max_num_seqs = vllm_config.scheduler_config.max_num_seqs + assert max_num_seqs > 0 + return max_num_seqs * blocks_per_request + 1 + + +def _get_request_constant_reserved_bytes( + vllm_config: VllmConfig, kv_cache_groups: list[KVCacheGroupSpec] +) -> int: + return sum( + _get_request_constant_num_blocks(vllm_config, group.kv_cache_spec) + * group.kv_cache_spec.physical_page_size_bytes + * len(group.layer_names) + for group in kv_cache_groups + if group.kv_cache_spec.memory_model == MemoryModel.REQUEST_CONSTANT + ) + + +def _get_legacy_num_blocks_from_pool_configs( + pool_configs: tuple[KVCachePoolConfig, ...], +) -> int: + assert pool_configs + for pool in pool_configs: + if pool.memory_model == MemoryModel.TOKEN_PROPORTIONAL: + return pool.num_blocks + return pool_configs[0].num_blocks + + +def _get_kv_cache_config_mixed_memory_model( + vllm_config: VllmConfig, + kv_cache_groups: list[KVCacheGroupSpec], + available_memory: int, +) -> KVCacheConfig: + token_group_ids = [ + group_id + for group_id, group in enumerate(kv_cache_groups) + if group.kv_cache_spec.memory_model == MemoryModel.TOKEN_PROPORTIONAL + ] + request_constant_group_ids = [ + group_id + for group_id, group in enumerate(kv_cache_groups) + if group.kv_cache_spec.memory_model == MemoryModel.REQUEST_CONSTANT + ] + + kv_cache_tensors: list[KVCacheTensor] = [] + pool_configs: list[KVCachePoolConfig] = [] + group_to_pool_id = [-1] * len(kv_cache_groups) + + request_constant_pool_specs: list[tuple[int, KVCachePoolConfig]] = [] + reserved_bytes = _get_request_constant_reserved_bytes(vllm_config, kv_cache_groups) + next_pool_id = 
1 if token_group_ids else 0 + for group_id in request_constant_group_ids: + group = kv_cache_groups[group_id] + spec = group.kv_cache_spec + num_blocks = _get_request_constant_num_blocks(vllm_config, spec) + pool_config = KVCachePoolConfig( + pool_id=next_pool_id, + memory_model=MemoryModel.REQUEST_CONSTANT, + group_ids=(group_id,), + num_blocks=num_blocks, + accounting_page_size_bytes=spec.accounting_page_size_bytes, + physical_page_size_bytes=spec.physical_page_size_bytes, + ) + request_constant_pool_specs.append((group_id, pool_config)) + next_pool_id += 1 + + if reserved_bytes >= available_memory: + raise ValueError( + "REQUEST_CONSTANT KV cache reservation " + f"({format_gib(reserved_bytes)} GiB) is not smaller than the " + f"available KV cache memory ({format_gib(available_memory)} GiB). " + "Try increasing `gpu_memory_utilization` or decreasing " + "`max_num_seqs`." + ) + + if token_group_ids: + token_groups = [kv_cache_groups[group_id] for group_id in token_group_ids] + token_kv_cache_config = get_kv_cache_config_from_groups( + vllm_config, token_groups, available_memory - reserved_bytes + ) + if token_kv_cache_config.num_blocks <= 0: + raise ValueError( + "No available memory remains for TOKEN_PROPORTIONAL KV cache " + "blocks after REQUEST_CONSTANT KV cache reservation." + ) + token_pool_config = token_kv_cache_config.pool_configs[0] + pool_configs.append( + replace( + token_pool_config, + pool_id=0, + group_ids=tuple(token_group_ids), + ) + ) + for group_id in token_group_ids: + group_to_pool_id[group_id] = 0 + kv_cache_tensors.extend(token_kv_cache_config.kv_cache_tensors) + + for group_id, pool_config in request_constant_pool_specs: + group = kv_cache_groups[group_id] + spec = group.kv_cache_spec + group_to_pool_id[group_id] = pool_config.pool_id + pool_configs.append(pool_config) + for layer_name in group.layer_names: + kv_cache_tensors.append( + KVCacheTensor( + size=pool_config.num_blocks * spec.physical_page_size_bytes, + shared_by=[layer_name], + ) + ) + + assert all(pool_id >= 0 for pool_id in group_to_pool_id) + pool_config_tuple = tuple(pool_configs) + return KVCacheConfig( + num_blocks=_get_legacy_num_blocks_from_pool_configs(pool_config_tuple), + kv_cache_tensors=kv_cache_tensors, + kv_cache_groups=kv_cache_groups, + pool_configs=pool_config_tuple, + group_to_pool_id=tuple(group_to_pool_id), + ) + + def get_kv_cache_config_from_groups( vllm_config: VllmConfig, kv_cache_groups: list[KVCacheGroupSpec], @@ -1253,6 +1471,17 @@ def get_kv_cache_config_from_groups( kv_cache_groups=kv_cache_groups, ) + if _has_request_constant_groups(kv_cache_groups): + if vllm_config.cache_config.enable_prefix_caching: + raise NotImplementedError( + "Prefix caching with REQUEST_CONSTANT groups is not yet " + "supported. Either disable prefix caching or use only " + "TOKEN_PROPORTIONAL specs." + ) + return _get_kv_cache_config_mixed_memory_model( + vllm_config, kv_cache_groups, available_memory + ) + # Determine how model runners should initialize the KV cache tensors. if len(kv_cache_groups) == 1 and isinstance( kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs @@ -1610,27 +1839,12 @@ def _annotate_eagle_groups_deepseek_v4( break -def get_kv_cache_groups( +def _get_token_proportional_kv_cache_groups( vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec] ) -> list[KVCacheGroupSpec]: - """ - Split the layers in the model into groups with the same KV cache spec. 
- - Args: - vllm_config: The global VllmConfig - kv_cache_spec: The kv cache spec of each attention layer in the model - - Returns: - The generated KVCacheGroups - """ if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: unify_hybrid_kv_cache_specs(kv_cache_spec) - if is_kv_cache_type_attention_free(kv_cache_spec): - # This returns an empty list to allow for the KVCacheManager to handle - # attention free models. - return [] - if is_kv_cache_spec_uniform(kv_cache_spec): # KV cache of all layers are the same, which is true for # most models. Allocate the same amount of memory for @@ -1661,6 +1875,68 @@ def get_kv_cache_groups( return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) +def _get_request_constant_kv_cache_groups( + kv_cache_spec: dict[str, KVCacheSpec], +) -> list[KVCacheGroupSpec]: + """Group request-constant specs without page-size unification.""" + same_spec_layers: dict[KVCacheSpec, list[str]] = defaultdict(list) + for layer_name, layer_spec in kv_cache_spec.items(): + same_spec_layers[layer_spec].append(layer_name) + return create_kv_cache_group_specs(kv_cache_spec, list(same_spec_layers.values())) + + +def _get_memory_model_aware_kv_cache_groups( + vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec] +) -> list[KVCacheGroupSpec]: + token_proportional_specs = { + layer_name: spec + for layer_name, spec in kv_cache_spec.items() + if spec.memory_model == MemoryModel.TOKEN_PROPORTIONAL + } + request_constant_specs = { + layer_name: spec + for layer_name, spec in kv_cache_spec.items() + if spec.memory_model == MemoryModel.REQUEST_CONSTANT + } + + kv_cache_groups: list[KVCacheGroupSpec] = [] + if token_proportional_specs: + kv_cache_groups.extend( + _get_token_proportional_kv_cache_groups( + vllm_config, token_proportional_specs + ) + ) + if request_constant_specs: + kv_cache_groups.extend( + _get_request_constant_kv_cache_groups(request_constant_specs) + ) + return kv_cache_groups + + +def get_kv_cache_groups( + vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec] +) -> list[KVCacheGroupSpec]: + """ + Split the layers in the model into groups with the same KV cache spec. + + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The kv cache spec of each attention layer in the model + + Returns: + The generated KVCacheGroups + """ + if is_kv_cache_type_attention_free(kv_cache_spec): + # This returns an empty list to allow for the KVCacheManager to handle + # attention free models. + return [] + + if _has_request_constant_specs(kv_cache_spec): + return _get_memory_model_aware_kv_cache_groups(vllm_config, kv_cache_spec) + + return _get_token_proportional_kv_cache_groups(vllm_config, kv_cache_spec) + + def generate_scheduler_kv_cache_config( kv_cache_configs: list[KVCacheConfig], ) -> KVCacheConfig: @@ -1670,6 +1946,26 @@ def generate_scheduler_kv_cache_config( assert all( [cfg.num_blocks == kv_cache_configs[0].num_blocks for cfg in kv_cache_configs] ) + assert all( + [ + cfg.group_to_pool_id == kv_cache_configs[0].group_to_pool_id + for cfg in kv_cache_configs + ] + ) + assert all( + [ + _get_pool_config_structure(cfg) + == _get_pool_config_structure(kv_cache_configs[0]) + for cfg in kv_cache_configs + ] + ) + assert all( + [ + tuple(pool.num_blocks for pool in cfg.pool_configs) + == tuple(pool.num_blocks for pool in kv_cache_configs[0].pool_configs) + for cfg in kv_cache_configs + ] + ) # All workers have the same kv_cache_config except layer names, so use # an arbitrary one to initialize the scheduler. 
cfg = copy.deepcopy(kv_cache_configs[0]) @@ -1680,9 +1976,23 @@ def generate_scheduler_kv_cache_config( group.kv_cache_spec = next( iter(group.kv_cache_spec.kv_cache_specs.values()) ) + cfg.refresh_legacy_pool_metadata() return cfg +def _get_pool_config_structure(kv_cache_config: KVCacheConfig): + return tuple( + ( + pool.pool_id, + pool.memory_model, + pool.group_ids, + pool.accounting_page_size_bytes, + pool.physical_page_size_bytes, + ) + for pool in kv_cache_config.pool_configs + ) + + def _report_kv_cache_config( vllm_config: VllmConfig, kv_cache_config: KVCacheConfig ) -> None: @@ -1729,6 +2039,22 @@ def _max_memory_usage_bytes_from_groups( if not kv_cache_groups: return 0 + if _has_request_constant_groups(kv_cache_groups): + request_constant_memory = _get_request_constant_reserved_bytes( + vllm_config, kv_cache_groups + ) + token_proportional_groups = [ + group + for group in kv_cache_groups + if group.kv_cache_spec.memory_model == MemoryModel.TOKEN_PROPORTIONAL + ] + token_proportional_memory = ( + _max_memory_usage_bytes_from_groups(vllm_config, token_proportional_groups) + if token_proportional_groups + else 0 + ) + return request_constant_memory + token_proportional_memory + if len(kv_cache_groups) == 1 and isinstance( kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs ): @@ -1919,6 +2245,78 @@ def _project_kv_cache_groups_to_worker( return projected_groups +def _get_tensor_pool_id( + kv_cache_config: KVCacheConfig, + tensor: KVCacheTensor, +) -> int: + layer_to_group_id = { + layer_name: group_id + for group_id, group in enumerate(kv_cache_config.kv_cache_groups) + for layer_name in group.layer_names + } + pool_ids = { + kv_cache_config.group_to_pool_id[layer_to_group_id[layer_name]] + for layer_name in tensor.shared_by + } + assert len(pool_ids) == 1 + return pool_ids.pop() + + +def _normalize_kv_cache_config_num_blocks( + kv_cache_configs: list[KVCacheConfig], +) -> None: + if not kv_cache_configs: + return + if not kv_cache_configs[0].pool_configs: + min_num_blocks = min( + kv_cache_config.num_blocks for kv_cache_config in kv_cache_configs + ) + for kv_cache_config in kv_cache_configs: + kv_cache_config.num_blocks = min_num_blocks + kv_cache_config.refresh_legacy_pool_metadata() + return + + num_pools = len(kv_cache_configs[0].pool_configs) + assert all(len(config.pool_configs) == num_pools for config in kv_cache_configs) + + normalized_pool_num_blocks: list[int] = [] + for pool_id in range(num_pools): + memory_model = kv_cache_configs[0].pool_configs[pool_id].memory_model + assert all( + config.pool_configs[pool_id].memory_model == memory_model + for config in kv_cache_configs + ) + pool_num_blocks = [ + config.pool_configs[pool_id].num_blocks for config in kv_cache_configs + ] + if memory_model == MemoryModel.TOKEN_PROPORTIONAL: + normalized_pool_num_blocks.append(min(pool_num_blocks)) + else: + assert len(set(pool_num_blocks)) == 1 + normalized_pool_num_blocks.append(pool_num_blocks[0]) + + for kv_cache_config in kv_cache_configs: + old_pool_num_blocks = tuple( + pool.num_blocks for pool in kv_cache_config.pool_configs + ) + kv_cache_config.pool_configs = tuple( + replace(pool, num_blocks=normalized_pool_num_blocks[pool.pool_id]) + for pool in kv_cache_config.pool_configs + ) + kv_cache_config.num_blocks = _get_legacy_num_blocks_from_pool_configs( + kv_cache_config.pool_configs + ) + + for tensor in kv_cache_config.kv_cache_tensors: + pool_id = _get_tensor_pool_id(kv_cache_config, tensor) + old_num_blocks = old_pool_num_blocks[pool_id] + new_num_blocks = 
normalized_pool_num_blocks[pool_id] + if old_num_blocks == new_num_blocks: + continue + assert tensor.size % old_num_blocks == 0 + tensor.size = tensor.size // old_num_blocks * new_num_blocks + + def get_kv_cache_configs( vllm_config: VllmConfig, kv_cache_specs: list[dict[str, KVCacheSpec]], @@ -1994,13 +2392,29 @@ def get_kv_cache_configs( if not groups: adjusted_memory.append(avail_mem) continue - bytes_per_block = _pool_bytes_per_block(groups) + if _has_request_constant_groups(groups): + token_groups = [ + group + for group in groups + if group.kv_cache_spec.memory_model + == MemoryModel.TOKEN_PROPORTIONAL + ] + request_constant_bytes = _get_request_constant_reserved_bytes( + vllm_config, groups + ) + bytes_per_block = ( + _pool_bytes_per_block(token_groups) if token_groups else 0 + ) + else: + request_constant_bytes = 0 + bytes_per_block = _pool_bytes_per_block(groups) + profiled_blocks = avail_mem // bytes_per_block if bytes_per_block else 0 logger.info( "Overriding num_gpu_blocks=%d with num_gpu_blocks_override=%d", - avail_mem // bytes_per_block, + profiled_blocks, override, ) - adjusted_memory.append(override * bytes_per_block) + adjusted_memory.append(request_constant_bytes + override * bytes_per_block) available_memory = adjusted_memory if vllm_config.model_config.original_max_model_len == -1: @@ -2032,21 +2446,11 @@ def get_kv_cache_configs( ) ) - # Change the num_blocks of each rank to the smallest among all ranks. - # We also need to shrink the tensor size proportionally to avoid - # allocating unused memory. - min_num_blocks = min( - kv_cache_config.num_blocks for kv_cache_config in kv_cache_configs - ) + # Change token-proportional num_blocks of each rank to the smallest among + # all ranks. Request-constant pool sizes are deterministic from + # max_num_seqs and are asserted equal rather than normalized. 
+ _normalize_kv_cache_config_num_blocks(kv_cache_configs) for kv_cache_config in kv_cache_configs: - num_blocks_old = kv_cache_config.num_blocks - kv_cache_config.num_blocks = min_num_blocks - - # Shrink tensor size proportionally - for tensor in kv_cache_config.kv_cache_tensors: - assert tensor.size % num_blocks_old == 0 - tensor.size = tensor.size // num_blocks_old * min_num_blocks - if len(kv_cache_config.kv_cache_groups) > 0: _report_kv_cache_config(vllm_config, kv_cache_config) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 032767cdf3b0..042074c2f095 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -38,6 +38,9 @@ ) from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector +from vllm.v1.core.kv_cache_utils import ( + get_token_proportional_kv_cache_capacity_tokens, +) from vllm.v1.core.sched.interface import PauseState, SchedulerInterface from vllm.v1.core.sched.output import ( CachedRequestData, @@ -52,7 +55,7 @@ ) from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs -from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, MemoryModel from vllm.v1.metrics.perf import ModelMetrics, PerfStats from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput @@ -277,16 +280,24 @@ def __init__( if isinstance(group.kv_cache_spec, AttentionSpec): self.routed_experts_attn_gid = gid break - min_block_size = min( - [ - group.kv_cache_spec.block_size - for group in kv_cache_config.kv_cache_groups - ] - ) - num_groups = len(kv_cache_config.kv_cache_groups) - self.max_num_kv_tokens = ( - kv_cache_config.num_blocks // num_groups - ) * min_block_size + if any( + pool.memory_model == MemoryModel.REQUEST_CONSTANT + for pool in kv_cache_config.pool_configs + ): + self.max_num_kv_tokens = ( + get_token_proportional_kv_cache_capacity_tokens(kv_cache_config) + ) + else: + min_block_size = min( + [ + group.kv_cache_spec.block_size + for group in kv_cache_config.kv_cache_groups + ] + ) + num_groups = len(kv_cache_config.kv_cache_groups) + self.max_num_kv_tokens = ( + kv_cache_config.num_blocks // num_groups + ) * min_block_size dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size if pcp_size * dcp_size > 1: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index e8d3a6f75688..ce9d36221496 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -4,9 +4,14 @@ from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Sequence +from typing import cast from vllm.utils.math_utils import cdiv -from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.block_pool import ( + BlockPool, + BlockPoolProtocol, + CacheableBlockPoolProtocol, +) from vllm.v1.core.kv_cache_utils import ( BlockHashList, BlockHashWithGroupId, @@ -18,6 +23,7 @@ FullAttentionSpec, KVCacheSpec, MambaSpec, + MemoryModel, MLAAttentionSpec, SinkFullAttentionSpec, SlidingWindowMLASpec, @@ -36,7 +42,7 @@ class SingleTypeKVCacheManager(ABC): def __init__( self, kv_cache_spec: KVCacheSpec, - 
block_pool: BlockPool, + block_pool: BlockPoolProtocol, enable_caching: bool, kv_cache_group_id: int, dcp_world_size: int = 1, @@ -81,6 +87,26 @@ def __init__( self.kv_cache_group_id = kv_cache_group_id self._null_block = block_pool.null_block + def _should_record_new_block_ids_for_zeroing(self) -> bool: + """Return whether newly allocated blocks should be zeroed. + + This preserves the legacy behavior exactly: only full-attention style + KV blocks are zeroed by the GPU worker. The broader + ``requires_block_zeroing_on_alloc`` spec property is not wired here yet + because the current zeroing kernel only targets attention KV tensors. + """ + return type(self.kv_cache_spec) in (FullAttentionSpec, TQFullAttentionSpec) + + def _cacheable_block_pool(self) -> CacheableBlockPoolProtocol: + """Return the block pool as a prefix-cache-capable pool. + + Compact/request-constant pools only implement the allocation protocol. + Prefix-cache paths are guarded by ``enable_caching`` and must only run + with a cacheable pool. + """ + assert self.enable_caching + return cast(CacheableBlockPoolProtocol, self.block_pool) + @classmethod def _get_num_evictable_blocks(cls, blocks: Sequence[KVCacheBlock]): return sum(blk.ref_cnt == 0 and not blk.is_null for blk in blocks) @@ -215,7 +241,7 @@ def allocate_new_computed_blocks( # Touch the computed blocks to make sure they won't be evicted. if self.enable_caching: - self.block_pool.touch(new_computed_blocks) + self._cacheable_block_pool().touch(new_computed_blocks) else: assert not any(new_computed_blocks), ( "Computed blocks should be empty when prefix caching is disabled" @@ -236,7 +262,7 @@ def allocate_new_computed_blocks( cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks) ) req_blocks.extend(allocated_blocks) - if type(self.kv_cache_spec) in (FullAttentionSpec, TQFullAttentionSpec): + if self._should_record_new_block_ids_for_zeroing(): self.new_block_ids.extend(b.block_id for b in allocated_blocks) def allocate_new_blocks( @@ -264,7 +290,7 @@ def allocate_new_blocks( else: new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) - if type(self.kv_cache_spec) in (FullAttentionSpec, TQFullAttentionSpec): + if self._should_record_new_block_ids_for_zeroing(): self.new_block_ids.extend(b.block_id for b in new_blocks) return new_blocks @@ -289,7 +315,7 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None: if num_cached_blocks >= num_full_blocks: return - self.block_pool.cache_full_blocks( + self._cacheable_block_pool().cache_full_blocks( request=request, blocks=self.req_to_blocks[request.request_id], num_cached_blocks=num_cached_blocks, @@ -340,7 +366,7 @@ def find_longest_cache_hit( block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], - block_pool: BlockPool, + block_pool: CacheableBlockPoolProtocol, kv_cache_spec: KVCacheSpec, use_eagle: bool, alignment_tokens: int, @@ -450,7 +476,7 @@ def find_longest_cache_hit( block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], - block_pool: BlockPool, + block_pool: CacheableBlockPoolProtocol, kv_cache_spec: KVCacheSpec, use_eagle: bool, alignment_tokens: int, @@ -515,7 +541,7 @@ def find_longest_cache_hit( block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], - block_pool: BlockPool, + block_pool: CacheableBlockPoolProtocol, kv_cache_spec: KVCacheSpec, use_eagle: bool, alignment_tokens: int, @@ -652,7 +678,7 @@ def find_longest_cache_hit( block_hashes: BlockHashList, max_length: int, 
kv_cache_group_ids: list[int], - block_pool: BlockPool, + block_pool: CacheableBlockPoolProtocol, kv_cache_spec: KVCacheSpec, use_eagle: bool, alignment_tokens: int, @@ -793,12 +819,17 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: class MambaManager(SingleTypeKVCacheManager): def __init__( - self, kv_cache_spec: MambaSpec, block_pool: BlockPool, **kwargs + self, kv_cache_spec: MambaSpec, block_pool: BlockPoolProtocol, **kwargs ) -> None: super().__init__(kv_cache_spec, block_pool, **kwargs) self.cached_blocks_this_step: set[BlockHashWithGroupId] = set() self.mamba_cache_mode = kv_cache_spec.mamba_cache_mode self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks + self.is_request_constant = ( + kv_cache_spec.memory_model == MemoryModel.REQUEST_CONSTANT + ) + if self.is_request_constant: + assert not self.enable_caching if self.mamba_cache_mode == "align": # Mapping from request ID to the index of the block # allocated in the previous step @@ -812,7 +843,7 @@ def find_longest_cache_hit( block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], - block_pool: BlockPool, + block_pool: CacheableBlockPoolProtocol, kv_cache_spec: KVCacheSpec, use_eagle: bool, alignment_tokens: int, @@ -864,6 +895,13 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> No # that we might actually need. num_computed_tokens = max(0, num_computed_tokens - self.num_speculative_blocks) + if self.is_request_constant and self.mamba_cache_mode != "align": + # Non-align compact Mamba keeps exactly one current state block + # plus optional speculative state blocks per request. Token-window + # based skipped-block eviction would incorrectly free that compact + # state because its block table is not token-proportional. + return + super().remove_skipped_blocks(request_id, num_computed_tokens) if self.mamba_cache_mode == "align": # `last_state_block_idx` refers to the block index allocated two steps ago. @@ -890,6 +928,13 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: """ return 0 + def _get_request_constant_num_blocks_to_allocate(self, request_id: str) -> int: + assert isinstance(self.kv_cache_spec, MambaSpec) + held_blocks = sum( + not block.is_null for block in self.req_to_blocks.get(request_id, ()) + ) + return max(self.kv_cache_spec.blocks_per_request - held_blocks, 0) + def get_num_blocks_to_allocate( self, request_id: str, @@ -909,6 +954,9 @@ def get_num_blocks_to_allocate( # that kv_cache_manager will think there is no enough blocks to allocate now # and don't schedule it in the current step. return self.block_pool.num_gpu_blocks + 1 + if self.is_request_constant and self.mamba_cache_mode != "align": + assert len(new_computed_blocks) == 0 + return self._get_request_constant_num_blocks_to_allocate(request_id) if self.mamba_cache_mode != "align": # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. 
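As a rough arithmetic sketch of the request-constant sizing this accounting
relies on (the concrete values below are illustrative assumptions, not
defaults):

    # Each request holds at most `blocks_per_request` compact state blocks:
    #   mamba_cache_mode "none":  1 + num_speculative_blocks
    #   mamba_cache_mode "align": 2 + num_speculative_blocks
    max_num_seqs = 8
    num_speculative_blocks = 1
    blocks_per_request = 2 + num_speculative_blocks       # "align" mode -> 3
    # The compact pool reserves block id 0 as the null sentinel, hence the +1.
    pool_num_blocks = max_num_seqs * blocks_per_request + 1   # -> 25 blocks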
@@ -961,6 +1009,16 @@ def allocate_new_blocks( self, request_id: str, num_tokens: int, num_tokens_main_model: int ) -> list[KVCacheBlock]: assert isinstance(self.kv_cache_spec, MambaSpec) + if self.is_request_constant and self.mamba_cache_mode != "align": + compact_req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id] + num_new_blocks = self._get_request_constant_num_blocks_to_allocate( + request_id + ) + if num_new_blocks <= 0: + return [] + new_blocks = self.block_pool.get_new_blocks(num_new_blocks) + compact_req_blocks.extend(new_blocks) + return new_blocks if self.mamba_cache_mode != "align": # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. @@ -1039,6 +1097,16 @@ def free(self, request_id: str) -> None: if self.mamba_cache_mode == "align": self._allocated_block_reqs.discard(request_id) self.last_state_block_idx.pop(request_id, None) + if self.is_request_constant: + # CompactBlockPool rejects freeing the null sentinel. Align-mode + # request block tables can contain null padding, so filter them + # before returning real compact blocks to the pool. + req_blocks = self.req_to_blocks.pop(request_id, []) + self.block_pool.free_blocks( + block for block in reversed(req_blocks) if not block.is_null + ) + self.num_cached_block.pop(request_id, None) + return super().free(request_id) def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: @@ -1096,7 +1164,7 @@ def find_longest_cache_hit( block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], - block_pool: BlockPool, + block_pool: CacheableBlockPoolProtocol, kv_cache_spec: KVCacheSpec, use_eagle: bool, alignment_tokens: int, @@ -1136,7 +1204,7 @@ def __init__( sink_len = kv_cache_spec.sink_len assert sink_len is not None and sink_len > 0 and sink_len % self.block_size == 0 num_sink_block = sink_len // self.block_size - self.sink_blocks = self.block_pool.free_block_queue.popleft_n(num_sink_block) + self.sink_blocks = block_pool.free_block_queue.popleft_n(num_sink_block) spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 19438fb1e42d..ffa20e36c908 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -6,7 +6,7 @@ import copy from collections import Counter from dataclasses import dataclass, fields, replace -from enum import IntEnum +from enum import Enum, IntEnum from math import prod from typing import TYPE_CHECKING @@ -77,6 +77,13 @@ def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool: return get_kv_quant_mode(kv_cache_dtype).is_per_token_head +class MemoryModel(Enum): + """How a KV cache spec's memory scales with request properties.""" + + TOKEN_PROPORTIONAL = "token_proportional" + REQUEST_CONSTANT = "request_constant" + + @dataclass(frozen=True) class KVCacheSpec: """ @@ -100,6 +107,34 @@ def page_size_bytes(self) -> int: def storage_block_size(self) -> int: return self.block_size + @property + def memory_model(self) -> MemoryModel: + """How memory usage scales for this KV cache spec.""" + return MemoryModel.TOKEN_PROPORTIONAL + + @property + def blocks_per_request(self) -> int: + """Maximum compact-pool blocks held by one request. + + This is only meaningful for ``REQUEST_CONSTANT`` specs. 
+ """ + return 1 + + @property + def accounting_page_size_bytes(self) -> int: + """Bytes used for allocator accounting.""" + return self.page_size_bytes + + @property + def physical_page_size_bytes(self) -> int: + """Bytes the backing cache tensor physically stores per page.""" + return self.page_size_bytes + + @property + def requires_block_zeroing_on_alloc(self) -> bool: + """Whether reused blocks must be zeroed before re-allocation.""" + return True + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: """ The maximum possible memory usage of this KV cache in bytes. @@ -538,20 +573,48 @@ class MambaSpec(KVCacheSpec): @property def page_size_bytes(self) -> int: - page_size = sum( - prod(shape) * get_dtype_size(dtype) - for (shape, dtype) in zip(self.shapes, self.dtypes) - ) + page_size = self.physical_page_size_bytes if self.page_size_padded is not None: assert self.page_size_padded >= page_size return self.page_size_padded return page_size + @property + def physical_page_size_bytes(self) -> int: + return sum( + prod(shape) * get_dtype_size(dtype) + for (shape, dtype) in zip(self.shapes, self.dtypes) + ) + + @property + def memory_model(self) -> MemoryModel: + # "all" mode preserves the legacy token-proportional shared-pool path, + # including prefix-caching compatibility. Other modes store compact + # O(1)-per-request state. + if self.mamba_cache_mode == "all": + return MemoryModel.TOKEN_PROPORTIONAL + return MemoryModel.REQUEST_CONSTANT + + @property + def blocks_per_request(self) -> int: + # "align" may transiently hold previous + current state blocks during a + # state-block transition. "none" holds exactly one state block. + # Speculative decoding adds one compact slot per speculative block. + if self.mamba_cache_mode == "align": + return 2 + self.num_speculative_blocks + return 1 + self.num_speculative_blocks + + @property + def requires_block_zeroing_on_alloc(self) -> bool: + # The current KVBlockZeroer skips Mamba tensors in all modes. Mamba + # state isolation is handled by the kernel/state-copy path instead. + return False + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - if vllm_config.cache_config.mamba_cache_mode == "all": + if self.mamba_cache_mode == "all": max_model_len = vllm_config.model_config.max_model_len return cdiv(max_model_len, self.block_size) * self.page_size_bytes - elif vllm_config.cache_config.mamba_cache_mode == "align": + elif self.mamba_cache_mode == "align": return self.page_size_bytes * (2 + self.num_speculative_blocks) else: return self.page_size_bytes * (1 + self.num_speculative_blocks) @@ -644,6 +707,12 @@ class UniformTypeKVCacheSpecs(KVCacheSpec): def page_size_bytes(self) -> int: return sum(spec.page_size_bytes for spec in self.kv_cache_specs.values()) + @property + def physical_page_size_bytes(self) -> int: + return sum( + spec.physical_page_size_bytes for spec in self.kv_cache_specs.values() + ) + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_num_pages = max( cdiv(spec.max_memory_usage_bytes(vllm_config), spec.page_size_bytes) @@ -756,6 +825,22 @@ class KVCacheGroupSpec: is_eagle_group: bool = False +@dataclass(frozen=True) +class KVCachePoolConfig: + """Metadata for one KV cache block-pool namespace. + + ``num_blocks`` includes the reserved null sentinel block. Therefore the + number of allocatable blocks in the pool is ``num_blocks - 1``. + """ + + pool_id: int + memory_model: MemoryModel + group_ids: tuple[int, ...] 
+ num_blocks: int + accounting_page_size_bytes: int + physical_page_size_bytes: int + + @dataclass class KVCacheConfig: """ @@ -774,6 +859,123 @@ class KVCacheConfig: For models with multiple types of attention, there will be multiple groups, see `_get_kv_cache_config_uniform_page_size` for more details. """ + pool_configs: tuple[KVCachePoolConfig, ...] = () + """Metadata for KV cache block-pool namespaces.""" + group_to_pool_id: tuple[int, ...] = () + """Mapping from KV cache group id to pool id.""" + + def __post_init__(self) -> None: + if len(self.kv_cache_groups) == 0: + assert self.pool_configs == () + assert self.group_to_pool_id == () + return + + if not self.pool_configs and not self.group_to_pool_id: + pool_config, group_to_pool_id = self._make_legacy_pool_metadata() + self.pool_configs = (pool_config,) + self.group_to_pool_id = group_to_pool_id + return + + assert len(self.group_to_pool_id) == len(self.kv_cache_groups) + pool_ids = {pool.pool_id for pool in self.pool_configs} + assert pool_ids == set(range(len(self.pool_configs))) + assert set(self.group_to_pool_id).issubset(pool_ids) + + def _make_legacy_pool_metadata( + self, + num_blocks: int | None = None, + ) -> tuple[KVCachePoolConfig, tuple[int, ...]]: + specs = [group.kv_cache_spec for group in self.kv_cache_groups] + memory_models = {spec.memory_model for spec in specs} + assert len(memory_models) == 1 + accounting_page_sizes = {spec.accounting_page_size_bytes for spec in specs} + assert len(accounting_page_sizes) == 1 + physical_page_sizes = {spec.physical_page_size_bytes for spec in specs} + if len(physical_page_sizes) == 1: + pool_physical_page_size_bytes = physical_page_sizes.pop() + else: + pool_physical_page_size_bytes = next(iter(accounting_page_sizes)) + + pool_config = KVCachePoolConfig( + pool_id=0, + memory_model=memory_models.pop(), + group_ids=tuple(range(len(self.kv_cache_groups))), + num_blocks=self.num_blocks if num_blocks is None else num_blocks, + accounting_page_size_bytes=accounting_page_sizes.pop(), + physical_page_size_bytes=pool_physical_page_size_bytes, + ) + return pool_config, tuple(0 for _ in self.kv_cache_groups) + + def _legacy_num_blocks_pool_id(self) -> int | None: + """Return the pool represented by the legacy ``num_blocks`` field.""" + if not self.pool_configs: + return None + for pool in self.pool_configs: + if pool.memory_model == MemoryModel.TOKEN_PROPORTIONAL: + return pool.pool_id + return self.pool_configs[0].pool_id + + def _refresh_multi_pool_metadata(self) -> None: + """Refresh explicit pool metadata without collapsing it to one pool.""" + assert self.pool_configs + assert len(self.group_to_pool_id) == len(self.kv_cache_groups) + + group_ids_by_pool: dict[int, list[int]] = { + pool.pool_id: [] for pool in self.pool_configs + } + for group_id, pool_id in enumerate(self.group_to_pool_id): + group_ids_by_pool[pool_id].append(group_id) + + legacy_pool_id = self._legacy_num_blocks_pool_id() + refreshed_pool_configs: list[KVCachePoolConfig] = [] + for pool in self.pool_configs: + group_ids = tuple(group_ids_by_pool[pool.pool_id]) + assert group_ids + specs = [ + self.kv_cache_groups[group_id].kv_cache_spec for group_id in group_ids + ] + memory_models = {spec.memory_model for spec in specs} + assert memory_models == {pool.memory_model} + + accounting_page_sizes = {spec.accounting_page_size_bytes for spec in specs} + assert len(accounting_page_sizes) == 1 + physical_page_sizes = {spec.physical_page_size_bytes for spec in specs} + if len(physical_page_sizes) == 1: + 
physical_page_size_bytes = physical_page_sizes.pop() + else: + physical_page_size_bytes = next(iter(accounting_page_sizes)) + + refreshed_pool_configs.append( + replace( + pool, + group_ids=group_ids, + num_blocks=( + self.num_blocks + if pool.pool_id == legacy_pool_id + else pool.num_blocks + ), + accounting_page_size_bytes=accounting_page_sizes.pop(), + physical_page_size_bytes=physical_page_size_bytes, + ) + ) + + self.pool_configs = tuple(refreshed_pool_configs) + + def refresh_legacy_pool_metadata(self) -> None: + """Regenerate pool metadata after mutating config fields.""" + if len(self.kv_cache_groups) == 0: + self.pool_configs = () + self.group_to_pool_id = () + return + if len(self.pool_configs) > 1 or any( + pool.memory_model == MemoryModel.REQUEST_CONSTANT + for pool in self.pool_configs + ): + self._refresh_multi_pool_metadata() + return + pool_config, group_to_pool_id = self._make_legacy_pool_metadata() + self.pool_configs = (pool_config,) + self.group_to_pool_id = group_to_pool_id @property def has_mamba_layers(self) -> bool: @@ -781,4 +983,11 @@ def has_mamba_layers(self) -> bool: @property def needs_kv_cache_zeroing(self) -> bool: + """Whether the GPU worker should initialize the attention KV zeroer. + + Despite the broad name, this gate currently controls the attention-only + ``KVBlockZeroer`` in hybrid models. The zeroer skips Mamba tensors, and + ``SingleTypeKVCacheManager`` only records full-attention block IDs for + zeroing. + """ return self.has_mamba_layers diff --git a/vllm/v1/simple_kv_offload/manager.py b/vllm/v1/simple_kv_offload/manager.py index 846526e5bee4..9d1a7f205d0d 100644 --- a/vllm/v1/simple_kv_offload/manager.py +++ b/vllm/v1/simple_kv_offload/manager.py @@ -5,7 +5,7 @@ import contextlib from collections.abc import Iterable from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from vllm.config import VllmConfig from vllm.distributed.kv_events import KVCacheEvent @@ -21,6 +21,7 @@ from vllm.v1.kv_cache_interface import ( FullAttentionSpec, MambaSpec, + MemoryModel, SlidingWindowSpec, ) from vllm.v1.outputs import KVConnectorOutput @@ -120,7 +121,7 @@ def __init__( pcp_world_size=pcp_world_size, hash_block_size=self.block_size, ) - self.cpu_block_pool: BlockPool = self.cpu_coordinator.block_pool + self.cpu_block_pool = cast(BlockPool, self.cpu_coordinator.block_pool) # GPU block pool reference - bound after scheduler builds kv_cache_manager self._gpu_block_pool: BlockPool | None = None @@ -167,6 +168,15 @@ def _derive_cpu_config( from vllm.v1.kv_cache_interface import KVCacheTensor assert len(gpu_config.kv_cache_tensors) > 0 + if any( + pool.memory_model == MemoryModel.REQUEST_CONSTANT + for pool in gpu_config.pool_configs + ): + raise NotImplementedError( + "CPU KV cache offload with REQUEST_CONSTANT specs " + "(Mamba in 'none' or 'align' mode) is not supported. Set " + "mamba_cache_mode='all' or disable offload." 
+ ) gpu_total_bytes = sum(t.size for t in gpu_config.kv_cache_tensors) num_gpu_blocks = gpu_config.num_blocks diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 226257581265..d9602671baa0 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -19,6 +19,7 @@ AttentionSpec, KVCacheConfig, KVCacheSpec, + MemoryModel, UniformTypeKVCacheSpecs, ) from vllm.v1.worker.utils import AttentionGroup, bind_kv_cache @@ -30,6 +31,13 @@ class AttentionCGSupportInfo: min_cg_attn_backend: str | None = None +def get_block_layout_page_size_bytes(spec: KVCacheSpec) -> int: + """Page size used for block-count and stride math during reshape.""" + if spec.memory_model == MemoryModel.REQUEST_CONSTANT: + return spec.physical_page_size_bytes + return spec.page_size_bytes + + def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]: kv_cache_spec: dict[str, KVCacheSpec] = {} layer_type = cast(type[Any], AttentionLayerBase) @@ -155,10 +163,16 @@ def _reshape_kv_cache( if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs): kv_cache_spec = kv_cache_spec.kv_cache_specs[layer_name] assert isinstance(kv_cache_spec, AttentionSpec) + if kv_cache_spec.memory_model == MemoryModel.REQUEST_CONSTANT: + raise NotImplementedError( + "REQUEST_CONSTANT AttentionSpec is not supported. " + "Attention KV cache is token-proportional." + ) raw_tensor = kv_cache_raw_tensors[layer_name] - assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 - num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes + block_layout_page_size = get_block_layout_page_size_bytes(kv_cache_spec) + assert raw_tensor.numel() % block_layout_page_size == 0 + num_blocks = raw_tensor.numel() // block_layout_page_size attn_backend = attn_backends[layer_name] kv_cache_shape = attn_backend.get_kv_cache_shape( @@ -194,7 +208,7 @@ def _reshape_kv_cache( # standard attention backends whose shape starts with # a K/V dimension of size 2. 
dtype_size = get_dtype_size(dtype) - page_stride = kv_cache_spec.page_size_bytes // dtype_size + page_stride = block_layout_page_size // dtype_size strides = list(torch.empty(kv_cache_shape).stride()) strides[inv_order[0]] = page_stride kv_cache = torch.as_strided( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bcab2ca2d4c2..53ee27682b60 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -143,6 +143,7 @@ KVCacheGroupSpec, KVCacheSpec, MambaSpec, + MemoryModel, SlidingWindowSpec, UniformTypeKVCacheSpecs, ) @@ -188,6 +189,7 @@ ) from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin +from vllm.v1.worker.gpu.attn_utils import get_block_layout_page_size_bytes from vllm.v1.worker.gpu.pool.late_interaction_runner import LateInteractionRunner from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper @@ -6648,9 +6650,15 @@ def _reshape_kv_cache_tensors( if layer_name in self.runner_only_attn_layers: continue raw_tensor = kv_cache_raw_tensors[layer_name] - assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 - num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes + block_layout_page_size = get_block_layout_page_size_bytes(kv_cache_spec) + assert raw_tensor.numel() % block_layout_page_size == 0 + num_blocks = raw_tensor.numel() // block_layout_page_size if isinstance(kv_cache_spec, AttentionSpec): + if kv_cache_spec.memory_model == MemoryModel.REQUEST_CONSTANT: + raise NotImplementedError( + "REQUEST_CONSTANT AttentionSpec is not supported. " + "Attention KV cache is token-proportional." + ) has_attn = True num_blocks_per_kv_block = ( kv_cache_spec.block_size // kernel_block_size @@ -6701,7 +6709,7 @@ def _reshape_kv_cache_tensors( # standard attention backends whose shape starts with # a K/V dimension of size 2. 
dtype_size = get_dtype_size(dtype) - page_stride = kv_cache_spec.page_size_bytes // dtype_size + page_stride = block_layout_page_size // dtype_size strides = list(torch.empty(kv_cache_shape).stride()) strides[inv_order[0]] = page_stride kv_cache = torch.as_strided( @@ -6721,9 +6729,7 @@ def _reshape_kv_cache_tensors( storage_offset_bytes = 0 for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes): dtype_size = get_dtype_size(dtype) - num_element_per_page = ( - kv_cache_spec.page_size_bytes // dtype_size - ) + num_element_per_page = block_layout_page_size // dtype_size target_shape = (num_blocks, *shape) stride = torch.empty(target_shape).stride() target_stride = (num_element_per_page, *stride[1:]) diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index 4fc1aff94fed..524e7b3315ef 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -18,12 +18,13 @@ from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger from vllm.v1.attention.backend import AttentionBackend -from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, MemoryModel from vllm.v1.outputs import ( EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput, ModelRunnerOutput, ) +from vllm.v1.worker.gpu.attn_utils import get_block_layout_page_size_bytes from vllm.v1.worker.utils import AttentionGroup if TYPE_CHECKING: @@ -219,6 +220,15 @@ def allocate_uniform_kv_caches( attn_group = attn_groups[0][0] kv_cache_spec = attn_group.kv_cache_spec assert isinstance(kv_cache_spec, AttentionSpec) + if any( + group.kv_cache_spec.memory_model == MemoryModel.REQUEST_CONSTANT + for groups in attn_groups + for group in groups + ): + raise NotImplementedError( + "Cross-layer KV connector does not support REQUEST_CONSTANT " + "specs. Multi-pool connector support is out of scope." 
+ ) tensor_sizes = set( kv_cache_tensor.size for kv_cache_tensor in kv_cache_config.kv_cache_tensors @@ -226,7 +236,7 @@ def allocate_uniform_kv_caches( assert len(tensor_sizes) == 1 tensor_size = tensor_sizes.pop() - page_size = kv_cache_spec.page_size_bytes + page_size = get_block_layout_page_size_bytes(kv_cache_spec) assert tensor_size % page_size == 0 num_blocks = tensor_size // page_size num_layers = len(kv_cache_config.kv_cache_tensors) From e33ff55fe857edae47cf915c25cc6d929363df16 Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Sat, 2 May 2026 12:21:08 +0900 Subject: [PATCH 2/6] Handle mixed KV override for cudagraph profiling Signed-off-by: lesj0610 (cherry picked from commit 378322e014aeab09467a98e2348c04fd168d9c6b) --- tests/v1/core/test_kv_cache_utils.py | 36 ++++++++++++++++++++++++++++ vllm/v1/core/kv_cache_utils.py | 15 ++++++++++++ 2 files changed, 51 insertions(+) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 5cd4ec96d527..e277a49fcea5 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -2296,6 +2296,42 @@ def test_request_constant_reservation_fails_closed_when_memory_exhausted(): ) +def test_request_constant_num_blocks_override_allows_minimal_config(): + vllm_config = make_request_constant_vllm_config(max_num_seqs=4) + vllm_config.cache_config.num_gpu_blocks_override = 1 + attention_spec = new_kv_cache_spec( + block_size=4, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ) + request_constant_spec = new_request_constant_spec( + num_speculative_blocks=1, + page_size_padded=16, + ) + request_constant_num_blocks = 4 * request_constant_spec.blocks_per_request + 1 + reserved_bytes = ( + request_constant_num_blocks * request_constant_spec.physical_page_size_bytes + ) + + config = kv_cache_utils.get_kv_cache_config_from_groups( + vllm_config, + [ + KVCacheGroupSpec(["attn"], attention_spec), + KVCacheGroupSpec(["state"], request_constant_spec), + ], + available_memory=0, + ) + + assert config.num_blocks == 1 + assert config.pool_configs[0].num_blocks == 1 + assert config.pool_configs[1].num_blocks == request_constant_num_blocks + assert config.kv_cache_tensors == [ + KVCacheTensor(size=attention_spec.page_size_bytes, shared_by=["attn"]), + KVCacheTensor(size=reserved_bytes, shared_by=["state"]), + ] + + def test_generate_uniform_type_kv_cache_specs(): # All layers are full attention, can be merged kv_cache_specs = { diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index abb3c571792a..07f033684cb8 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1375,6 +1375,21 @@ def _get_kv_cache_config_mixed_memory_model( request_constant_pool_specs: list[tuple[int, KVCachePoolConfig]] = [] reserved_bytes = _get_request_constant_reserved_bytes(vllm_config, kv_cache_groups) + + # CUDA graph memory profiling temporarily sets num_gpu_blocks_override and + # asks for a minimal KV cache with available_memory=0. The single-pool + # TOKEN_PROPORTIONAL path already honors that override after deriving an + # initial block count from available_memory. Mirror that behavior here so + # mixed-memory configs do not fail before the token pool gets a chance to + # apply the override. REQUEST_CONSTANT pool sizes remain deterministic. 
+ override = vllm_config.cache_config.num_gpu_blocks_override + if override is not None and token_group_ids: + token_groups = [kv_cache_groups[group_id] for group_id in token_group_ids] + available_memory = max( + available_memory, + reserved_bytes + override * _pool_bytes_per_block(token_groups), + ) + next_pool_id = 1 if token_group_ids else 0 for group_id in request_constant_group_ids: group = kv_cache_groups[group_id] From 47af837e7e2d1ad8c1af431462a3b64ee0009758 Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Sat, 2 May 2026 16:31:11 +0900 Subject: [PATCH 3/6] Fix request-constant KV review edge cases Co-authored-by: OpenAI Codex Signed-off-by: lesj0610 --- tests/v1/core/test_kv_cache_utils.py | 37 ++++++++++++++++++++++++++++ tests/v1/core/test_prefix_caching.py | 36 +++++++++++++++++++++++++++ vllm/v1/core/block_pool.py | 8 ++++-- vllm/v1/core/kv_cache_manager.py | 2 +- vllm/v1/core/kv_cache_utils.py | 25 +++++++++++++------ 5 files changed, 98 insertions(+), 10 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index e277a49fcea5..a34f7a07831d 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -2332,6 +2332,43 @@ def test_request_constant_num_blocks_override_allows_minimal_config(): ] +def test_request_constant_only_num_blocks_override_allows_minimal_config(): + vllm_config = make_request_constant_vllm_config(max_num_seqs=4) + vllm_config.cache_config.num_gpu_blocks_override = 1 + request_constant_spec = new_request_constant_spec( + num_speculative_blocks=1, + page_size_padded=16, + ) + request_constant_num_blocks = 4 * request_constant_spec.blocks_per_request + 1 + reserved_bytes = ( + request_constant_num_blocks * request_constant_spec.physical_page_size_bytes + ) + + config = kv_cache_utils.get_kv_cache_config_from_groups( + vllm_config, + [KVCacheGroupSpec(["state"], request_constant_spec)], + available_memory=0, + ) + + assert config.num_blocks == request_constant_num_blocks + assert config.group_to_pool_id == (0,) + assert config.pool_configs == ( + KVCachePoolConfig( + pool_id=0, + memory_model=MemoryModel.REQUEST_CONSTANT, + group_ids=(0,), + num_blocks=request_constant_num_blocks, + accounting_page_size_bytes=( + request_constant_spec.accounting_page_size_bytes + ), + physical_page_size_bytes=request_constant_spec.physical_page_size_bytes, + ), + ) + assert config.kv_cache_tensors == [ + KVCacheTensor(size=reserved_bytes, shared_by=["state"]), + ] + + def test_generate_uniform_type_kv_cache_specs(): # All layers are full attention, can be merged kv_cache_specs = { diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index b21093dee721..ef7f0b93e438 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1992,6 +1992,42 @@ def test_kv_cache_events(blocks_to_cache: int): assert len(manager.block_pool.cached_block_hash_to_block) == 0 +def test_request_constant_only_kv_cache_events_noop(): + block_size = 4 + mamba_spec = MambaSpec( + block_size=block_size, + shapes=((1,),), + dtypes=(torch.float32,), + mamba_cache_mode="none", + ) + kv_cache_config = KVCacheConfig( + num_blocks=3, + kv_cache_tensors=[], + kv_cache_groups=[KVCacheGroupSpec(["mamba"], mamba_spec)], + pool_configs=( + KVCachePoolConfig( + pool_id=0, + memory_model=MemoryModel.REQUEST_CONSTANT, + group_ids=(0,), + num_blocks=3, + accounting_page_size_bytes=mamba_spec.accounting_page_size_bytes, + 
physical_page_size_bytes=mamba_spec.physical_page_size_bytes, + ), + ), + group_to_pool_id=(0,), + ) + + manager = KVCacheManager( + kv_cache_config, + max_model_len=8192, + enable_caching=False, + enable_kv_cache_events=True, + hash_block_size=block_size, + ) + + assert manager.take_events() == [] + + def test_null_parent_block_hash(): block_size = 1 num_cached_blocks = 2 diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 638944c0a14d..7ffc7bac5f7c 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -50,6 +50,8 @@ def get_num_free_blocks(self) -> int: ... def get_usage(self) -> float: ... + def take_events(self) -> list[KVCacheEvent]: ... + class CacheableBlockPoolProtocol(BlockPoolProtocol, Protocol): """Block-pool contract for pools that also support prefix caching.""" @@ -74,8 +76,6 @@ def evict_blocks(self, block_ids: set[int]) -> None: ... def reset_prefix_cache(self) -> bool: ... - def take_events(self) -> list[KVCacheEvent]: ... - class CompactBlockPool: """Compact allocation-only pool for request-constant KV-cache specs. @@ -139,6 +139,10 @@ def get_usage(self) -> float: return 0 return 1.0 - (self.get_num_free_blocks() / self._num_allocatable) + def take_events(self) -> list[KVCacheEvent]: + """Compact pools do not emit prefix-cache events.""" + return [] + class BlockHashToBlockMap: """ diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 3bf3dcbe6fc1..c1d7ced1095a 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -539,7 +539,7 @@ def take_events(self) -> list[KVCacheEvent]: Returns: A list of KV cache events. """ - return self._cacheable_block_pool().take_events() + return self.block_pool.take_events() def get_blocks(self, request_id: str) -> KVCacheBlocks: """Get the blocks of a request.""" diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 07f033684cb8..e0a69efc465c 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1383,12 +1383,15 @@ def _get_kv_cache_config_mixed_memory_model( # mixed-memory configs do not fail before the token pool gets a chance to # apply the override. REQUEST_CONSTANT pool sizes remain deterministic. 
override = vllm_config.cache_config.num_gpu_blocks_override - if override is not None and token_group_ids: - token_groups = [kv_cache_groups[group_id] for group_id in token_group_ids] - available_memory = max( - available_memory, - reserved_bytes + override * _pool_bytes_per_block(token_groups), - ) + if override is not None: + if token_group_ids: + token_groups = [kv_cache_groups[group_id] for group_id in token_group_ids] + available_memory = max( + available_memory, + reserved_bytes + override * _pool_bytes_per_block(token_groups), + ) + else: + available_memory = max(available_memory, reserved_bytes) next_pool_id = 1 if token_group_ids else 0 for group_id in request_constant_group_ids: @@ -1406,7 +1409,12 @@ def _get_kv_cache_config_mixed_memory_model( request_constant_pool_specs.append((group_id, pool_config)) next_pool_id += 1 - if reserved_bytes >= available_memory: + reservation_exhausts_memory = ( + reserved_bytes >= available_memory + if token_group_ids + else reserved_bytes > available_memory + ) + if reservation_exhausts_memory: raise ValueError( "REQUEST_CONSTANT KV cache reservation " f"({format_gib(reserved_bytes)} GiB) is not smaller than the " @@ -2326,6 +2334,9 @@ def _normalize_kv_cache_config_num_blocks( pool_id = _get_tensor_pool_id(kv_cache_config, tensor) old_num_blocks = old_pool_num_blocks[pool_id] new_num_blocks = normalized_pool_num_blocks[pool_id] + assert old_num_blocks > 0, ( + "KV cache pool num_blocks includes the null block and must be positive." + ) if old_num_blocks == new_num_blocks: continue assert tensor.size % old_num_blocks == 0 From 291b0638ab2830343d3ef1aff1c121b69edcad12 Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Sat, 2 May 2026 19:23:02 +0900 Subject: [PATCH 4/6] Restore non-divisible KV page-size test expectation Keep the existing fail-closed behavior for hybrid specs whose page sizes cannot be aligned by block-size adjustment. Co-authored-by: OpenAI Codex Signed-off-by: lesj0610 --- tests/v1/core/test_kv_cache_utils.py | 29 +++++----------------------- 1 file changed, 5 insertions(+), 24 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index a34f7a07831d..b353fda1ddaa 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -1809,35 +1809,16 @@ def test_get_kv_cache_config_one_worker(): ], ) - # Different hidden size and different type that cannot be aligned by using - # different block size. This is supported by padding the smaller page. 
+ # different hidden size that cannot be aligned by using different block size kv_cache_specs_hybrid = { "layer_1": new_kv_cache_spec(head_size=64), "layer_2": new_sliding_window_spec(head_size=96), } - kv_cache_config_hybrid = get_kv_cache_configs( - vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] - )[0] - assert kv_cache_config_hybrid == KVCacheConfig( - num_blocks=42, - kv_cache_tensors=[ - KVCacheTensor( - size=mem_per_block_per_layer * 3 // 2 * 42, - shared_by=["layer_1", "layer_2"], - ), - ], - kv_cache_groups=[ - KVCacheGroupSpec( - ["layer_1"], - new_kv_cache_spec( - head_size=64, - page_size_padded=mem_per_block_per_layer * 3 // 2, - ), - ), - KVCacheGroupSpec(["layer_2"], new_sliding_window_spec(head_size=96)), - ], - ) + with pytest.raises(NotImplementedError): + get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] + )[0] # Test num_gpu_blocks_override vllm_config.cache_config.num_gpu_blocks_override = 16 From 1ba157148cdf92cf4ee82df87c3801289e63768d Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Sat, 2 May 2026 21:29:50 +0900 Subject: [PATCH 5/6] Allow full cudagraph with request-constant KV pools Validate request-constant pool capacity with max_num_seqs instead of rejecting full CUDA graph capture outright. Co-authored-by: OpenAI Codex Signed-off-by: lesj0610 --- tests/v1/core/test_kv_cache_utils.py | 85 +++++++++++++++++++++++++--- vllm/config/compilation.py | 60 ++++++++++++++++---- 2 files changed, 127 insertions(+), 18 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index b353fda1ddaa..626034387cd1 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -2084,8 +2084,12 @@ def test_token_proportional_capacity_ignores_request_constant_pool(): assert get_max_concurrency_for_kv_cache_config(vllm_config, config) == 4 -def test_request_constant_mamba_full_cudagraph_fails_closed(): - vllm_config = make_request_constant_vllm_config(max_num_seqs=4) +def _make_request_constant_mamba_cudagraph_config( + max_num_seqs: int = 4, + mamba_cache_mode: str = "none", + num_speculative_blocks: int = 2, +) -> KVCacheConfig: + vllm_config = make_request_constant_vllm_config(max_num_seqs=max_num_seqs) attention_spec = new_kv_cache_spec( block_size=4, num_kv_heads=1, @@ -2096,27 +2100,94 @@ def test_request_constant_mamba_full_cudagraph_fails_closed(): block_size=4, shapes=((4,),), dtypes=(torch.float32,), - mamba_cache_mode="none", + mamba_cache_mode=mamba_cache_mode, + num_speculative_blocks=num_speculative_blocks, page_size_padded=attention_spec.page_size_bytes, ) - mamba_num_blocks = 4 * mamba_spec.blocks_per_request + 1 + mamba_num_blocks = max_num_seqs * mamba_spec.blocks_per_request + 1 mamba_reserved_bytes = mamba_num_blocks * mamba_spec.physical_page_size_bytes - kv_cache_config = get_kv_cache_configs( + return get_kv_cache_configs( vllm_config, [{"attn": attention_spec, "mamba": mamba_spec}], [mamba_reserved_bytes + attention_spec.page_size_bytes * 16], )[0] + + +def test_request_constant_mamba_full_cudagraph_uses_pool_capacity(): + kv_cache_config = _make_request_constant_mamba_cudagraph_config() + compilation_config = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL) + + cudagraph_mode = compilation_config.resolve_cudagraph_mode_and_sizes( + min_cg_support=AttentionCGSupport.ALWAYS, + min_cg_attn_backend="test", + kv_cache_config=kv_cache_config, + max_num_reqs=4, + ) + + assert cudagraph_mode == CUDAGraphMode.FULL + + 
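# Illustrative sketch (not part of this patch) of the capacity rule that
# _validate_request_constant_cudagraph_capacity enforces for full CUDA graph
# capture: every scheduled request must fit in the REQUEST_CONSTANT pool,
# whose block 0 is the reserved null sentinel. fits_full_cudagraph and its
# parameters are stand-in names for the KVCachePoolConfig / KVCacheSpec
# fields involved.
def fits_full_cudagraph(
    max_num_reqs: int, blocks_per_request: int, pool_num_blocks: int
) -> bool:
    usable_blocks = pool_num_blocks - 1  # exclude the reserved null block
    return max_num_reqs * blocks_per_request <= usable_blocks


# Mirrors the fixture above: 4 sequences x 3 compact blocks each
# (1 state + 2 speculative) fit exactly in a 13-block pool; 5 sequences do not.
assert fits_full_cudagraph(4, 3, 13)
assert not fits_full_cudagraph(5, 3, 13)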
+def test_request_constant_mamba_full_cudagraph_rejects_small_pool(): + kv_cache_config = _make_request_constant_mamba_cudagraph_config() + compilation_config = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL) + + with pytest.raises( + ValueError, + match="REQUEST_CONSTANT KV cache blocks", + ): + compilation_config.resolve_cudagraph_mode_and_sizes( + min_cg_support=AttentionCGSupport.ALWAYS, + min_cg_attn_backend="test", + kv_cache_config=kv_cache_config, + max_num_reqs=5, + ) + + +def test_request_constant_mamba_full_cudagraph_align_uses_blocks_per_request(): + kv_cache_config = _make_request_constant_mamba_cudagraph_config( + mamba_cache_mode="align", + num_speculative_blocks=1, + ) + compilation_config = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL) + + cudagraph_mode = compilation_config.resolve_cudagraph_mode_and_sizes( + min_cg_support=AttentionCGSupport.ALWAYS, + min_cg_attn_backend="test", + kv_cache_config=kv_cache_config, + max_num_reqs=4, + ) + + assert cudagraph_mode == CUDAGraphMode.FULL + + +def test_request_constant_mamba_full_cudagraph_skips_profiling_capacity(): + kv_cache_config = _make_request_constant_mamba_cudagraph_config() + compilation_config = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL) + + cudagraph_mode = compilation_config.resolve_cudagraph_mode_and_sizes( + min_cg_support=AttentionCGSupport.ALWAYS, + min_cg_attn_backend="test", + kv_cache_config=kv_cache_config, + max_num_reqs=5, + is_profiling=True, + ) + + assert cudagraph_mode == CUDAGraphMode.FULL + + +def test_request_constant_mamba_full_cudagraph_requires_max_num_reqs(): + kv_cache_config = _make_request_constant_mamba_cudagraph_config() compilation_config = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL) with pytest.raises( ValueError, - match="Full CUDA graph capture with REQUEST_CONSTANT KV cache", + match="requires max_num_seqs for capacity validation", ): compilation_config.resolve_cudagraph_mode_and_sizes( min_cg_support=AttentionCGSupport.ALWAYS, min_cg_attn_backend="test", kv_cache_config=kv_cache_config, - max_num_reqs=4, + max_num_reqs=None, ) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 343f25c6c197..0ac55376903d 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -1298,6 +1298,43 @@ def is_custom_op_enabled(self, op: str) -> bool: assert "none" in self.custom_ops return f"+{op}" in self.custom_ops + @staticmethod + def _has_request_constant_kv_pools(kv_cache_config: "KVCacheConfig") -> bool: + from vllm.v1.kv_cache_interface import MemoryModel + + return any( + pool.memory_model == MemoryModel.REQUEST_CONSTANT + for pool in kv_cache_config.pool_configs + ) + + @staticmethod + def _validate_request_constant_cudagraph_capacity( + kv_cache_config: "KVCacheConfig", + max_num_reqs: int, + ) -> None: + from vllm.v1.kv_cache_interface import MemoryModel + + for pool in kv_cache_config.pool_configs: + if pool.memory_model != MemoryModel.REQUEST_CONSTANT: + continue + + blocks_per_request = 0 + for group_id in pool.group_ids: + spec = kv_cache_config.kv_cache_groups[group_id].kv_cache_spec + blocks_per_request += spec.blocks_per_request + required_blocks = max_num_reqs * blocks_per_request + usable_blocks = pool.num_blocks - 1 + if required_blocks <= usable_blocks: + continue + + raise ValueError( + f"max_num_seqs ({max_num_reqs}) requires " + f"{required_blocks} REQUEST_CONSTANT KV cache blocks for " + f"pool {pool.pool_id}, but only {usable_blocks} are available. " + "Full CUDA graph capture cannot proceed. 
Please lower " + "max_num_seqs or increase gpu_memory_utilization." + ) + def resolve_cudagraph_mode_and_sizes( self, min_cg_support: "AttentionCGSupport", @@ -1417,25 +1454,25 @@ def resolve_cudagraph_mode_and_sizes( tensor_parallel_size, ) + has_request_constant_pools = ( + kv_cache_config is not None + and self._has_request_constant_kv_pools(kv_cache_config) + ) + if ( kv_cache_config is not None and cudagraph_mode.has_full_cudagraphs() and not is_profiling - and kv_cache_config.has_mamba_layers + and has_request_constant_pools ): - from vllm.v1.kv_cache_interface import MemoryModel - - if any( - pool.memory_model == MemoryModel.REQUEST_CONSTANT - for pool in kv_cache_config.pool_configs - ): + if max_num_reqs is None: raise ValueError( "Full CUDA graph capture with REQUEST_CONSTANT KV cache " - "(Mamba in 'none' or 'align' mode) is not yet supported. " - "Either disable cudagraph capture (e.g., enforce_eager=True) " - "or set mamba_cache_mode='all' to use the legacy " - "shared-pool path." + "requires max_num_seqs for capacity validation." ) + self._validate_request_constant_cudagraph_capacity( + kv_cache_config, max_num_reqs + ) # For Mamba models with FULL decode cudagraphs, each decode # sequence needs one Mamba cache block. The decode cudagraph @@ -1450,6 +1487,7 @@ def resolve_cudagraph_mode_and_sizes( and cudagraph_mode.has_full_cudagraphs() and not is_profiling and kv_cache_config.has_mamba_layers + and not has_request_constant_pools and max_num_reqs > kv_cache_config.num_blocks ): raise ValueError( From e52ba7a087842aed321cb00bb89f8ae67a269dca Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Sun, 3 May 2026 13:25:27 +0900 Subject: [PATCH 6/6] Fix legacy KV cache metadata for connector tests Signed-off-by: lesj0610 --- tests/v1/core/test_kv_cache_utils.py | 84 ++++++++++++++++++++++++++++ tests/v1/kv_connector/unit/utils.py | 3 + vllm/v1/kv_cache_interface.py | 30 +++++----- 3 files changed, 104 insertions(+), 13 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 626034387cd1..687d88d30460 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -244,6 +244,90 @@ def assert_legacy_single_pool_metadata(config: KVCacheConfig) -> None: assert pool_config.physical_page_size_bytes == accounting_page_size +def test_legacy_pool_metadata_keeps_mixed_memory_models_shared(): + attention_spec = new_kv_cache_spec( + block_size=4, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ) + mamba_spec = new_mamba_spec( + block_size=4, + shapes=((4,),), + dtypes=(torch.float32,), + mamba_cache_mode="none", + ) + + config = KVCacheConfig( + num_blocks=7, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(["attn"], attention_spec), + KVCacheGroupSpec(["mamba"], mamba_spec), + ], + ) + + assert config.group_to_pool_id == (0, 0) + assert config.pool_configs == ( + KVCachePoolConfig( + pool_id=0, + memory_model=MemoryModel.TOKEN_PROPORTIONAL, + group_ids=(0, 1), + num_blocks=7, + accounting_page_size_bytes=max( + attention_spec.accounting_page_size_bytes, + mamba_spec.accounting_page_size_bytes, + ), + physical_page_size_bytes=max( + attention_spec.accounting_page_size_bytes, + mamba_spec.accounting_page_size_bytes, + ), + ), + ) + + +def test_legacy_pool_metadata_keeps_different_page_sizes_shared(): + small_spec = FullAttentionSpec( + block_size=12, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ) + large_spec = FullAttentionSpec( + block_size=16, + num_kv_heads=1, + head_size=1, + 
dtype=torch.float32, + ) + + config = KVCacheConfig( + num_blocks=11, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(["small"], small_spec), + KVCacheGroupSpec(["large"], large_spec), + ], + ) + + assert config.group_to_pool_id == (0, 0) + assert config.pool_configs == ( + KVCachePoolConfig( + pool_id=0, + memory_model=MemoryModel.TOKEN_PROPORTIONAL, + group_ids=(0, 1), + num_blocks=11, + accounting_page_size_bytes=max( + small_spec.accounting_page_size_bytes, + large_spec.accounting_page_size_bytes, + ), + physical_page_size_bytes=max( + small_spec.accounting_page_size_bytes, + large_spec.accounting_page_size_bytes, + ), + ), + ) + + @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_none_hash(monkeypatch, hash_fn): import vllm.v1.core.kv_cache_utils diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 8e4e1cae0676..00ddbc77ac7c 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -469,6 +469,9 @@ def make_kv_cache_config( block_size=block_size, shapes=((16,), (16,)), dtypes=(torch.float16,), + # These connector tests exercise the legacy shared-pool + # prefix-cache path, not request-constant Mamba allocation. + mamba_cache_mode="all", ), ) ) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index ffa20e36c908..d36620f7106b 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -871,8 +871,8 @@ def __post_init__(self) -> None: return if not self.pool_configs and not self.group_to_pool_id: - pool_config, group_to_pool_id = self._make_legacy_pool_metadata() - self.pool_configs = (pool_config,) + pool_configs, group_to_pool_id = self._make_legacy_pool_metadata() + self.pool_configs = pool_configs self.group_to_pool_id = group_to_pool_id return @@ -884,27 +884,31 @@ def __post_init__(self) -> None: def _make_legacy_pool_metadata( self, num_blocks: int | None = None, - ) -> tuple[KVCachePoolConfig, tuple[int, ...]]: + ) -> tuple[tuple[KVCachePoolConfig, ...], tuple[int, ...]]: + """Derive pool metadata for legacy direct ``KVCacheConfig`` callers. + + Production multi-pool configs should pass explicit pool metadata. This + fallback keeps the old single shared block-pool behavior. It cannot + compute request-constant compact capacity because it has no scheduler + context. 
+ """ specs = [group.kv_cache_spec for group in self.kv_cache_groups] - memory_models = {spec.memory_model for spec in specs} - assert len(memory_models) == 1 - accounting_page_sizes = {spec.accounting_page_size_bytes for spec in specs} - assert len(accounting_page_sizes) == 1 + accounting_page_size = max(spec.accounting_page_size_bytes for spec in specs) physical_page_sizes = {spec.physical_page_size_bytes for spec in specs} if len(physical_page_sizes) == 1: pool_physical_page_size_bytes = physical_page_sizes.pop() else: - pool_physical_page_size_bytes = next(iter(accounting_page_sizes)) + pool_physical_page_size_bytes = accounting_page_size pool_config = KVCachePoolConfig( pool_id=0, - memory_model=memory_models.pop(), + memory_model=MemoryModel.TOKEN_PROPORTIONAL, group_ids=tuple(range(len(self.kv_cache_groups))), num_blocks=self.num_blocks if num_blocks is None else num_blocks, - accounting_page_size_bytes=accounting_page_sizes.pop(), + accounting_page_size_bytes=accounting_page_size, physical_page_size_bytes=pool_physical_page_size_bytes, ) - return pool_config, tuple(0 for _ in self.kv_cache_groups) + return (pool_config,), tuple(0 for _ in self.kv_cache_groups) def _legacy_num_blocks_pool_id(self) -> int | None: """Return the pool represented by the legacy ``num_blocks`` field.""" @@ -973,8 +977,8 @@ def refresh_legacy_pool_metadata(self) -> None: ): self._refresh_multi_pool_metadata() return - pool_config, group_to_pool_id = self._make_legacy_pool_metadata() - self.pool_configs = (pool_config,) + pool_configs, group_to_pool_id = self._make_legacy_pool_metadata() + self.pool_configs = pool_configs self.group_to_pool_id = group_to_pool_id @property