diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index 8153fed699fe..08463a2800c2 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -43,7 +43,6 @@
     KVCacheGroupSpec,
     KVCacheSpec,
     KVCacheTensor,
-    MambaSpec,
     MLAAttentionSpec,
     SlidingWindowSpec,
     UniformTypeKVCacheSpecs,
@@ -158,24 +157,6 @@ def new_chunked_local_attention_spec(
     )


-def new_mamba_spec(
-    block_size=16,
-    shapes=((2, 512), (3, 32, 32)),
-    dtypes=(torch.float32, torch.float32),
-    num_speculative_blocks=2,
-    mamba_cache_mode="none",
-    page_size_padded=None,
-):
-    return MambaSpec(
-        block_size=block_size,
-        shapes=shapes,
-        dtypes=dtypes,
-        page_size_padded=page_size_padded,
-        mamba_cache_mode=mamba_cache_mode,
-        num_speculative_blocks=num_speculative_blocks,
-    )
-
-
 @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
 def test_none_hash(monkeypatch, hash_fn):
     import vllm.v1.core.kv_cache_utils
@@ -2029,28 +2010,6 @@ def test_auto_fit_max_model_len():
     assert vllm_config.model_config.max_model_len > 0


-def test_auto_fit_max_model_len_with_hybrid():
-    """Test that auto-fit works with hybrid KV cache specs."""
-    # Create config with original_max_model_len=-1 to trigger auto-fit
-    model_config = ModelConfig(max_model_len=8192)
-    # Simulate the user passing -1 by setting original_max_model_len
-    model_config.original_max_model_len = -1
-    vllm_config = VllmConfig(model_config=model_config)
-
-    mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2  # 16KB per block per layer
-    gamma = 2
-    kv_cache_specs = {
-        "layer_1": new_mamba_spec(num_speculative_blocks=gamma),
-        "layer_2": new_kv_cache_spec(),
-    }
-
-    available_memory = mem_per_block_per_layer * (1024 // 16 + 1 + gamma)
-    _kv_cache_configs = get_kv_cache_configs(
-        vllm_config, [kv_cache_specs], [available_memory]
-    )
-    assert vllm_config.model_config.max_model_len == 1024
-
-
 def test_auto_fit_max_model_len_not_triggered():
     """Test that auto-fit is not triggered when original_max_model_len is not -1."""
     model_config = ModelConfig(max_model_len=16)
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 83ada05309f9..3da3d7e7bef7 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -1356,10 +1356,8 @@ def _max_memory_usage_bytes_from_groups(
     page_size = get_uniform_page_size(
         [group.kv_cache_spec for group in kv_cache_groups]
     )
-    blocks_needed = sum(
-        cdiv(group.kv_cache_spec.max_memory_usage_bytes(vllm_config), page_size)
-        for group in kv_cache_groups
-    )
+    any_spec = kv_cache_groups[0].kv_cache_spec
+    blocks_needed = cdiv(any_spec.max_memory_usage_bytes(vllm_config), page_size)
     return group_size * page_size * blocks_needed

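Note on the `_max_memory_usage_bytes_from_groups` change: the rewritten code
takes the first group's spec as representative (`any_spec`), which appears to
assume every group reports the same `max_memory_usage_bytes` once the page
size is uniform across groups. A minimal sketch of the resulting arithmetic,
using illustrative stand-in names (the `cdiv` and `max_memory_usage_bytes`
helpers below are not vLLM's actual signatures):

    def cdiv(a: int, b: int) -> int:
        """Ceiling division: smallest integer >= a / b."""
        return -(a // -b)

    def max_memory_usage_bytes(group_size: int,
                               page_size: int,
                               bytes_per_group: int) -> int:
        # One representative group's requirement, rounded up to whole
        # pages, gives the per-group block count; the total then scales
        # by group_size rather than summing over every group.
        blocks_needed = cdiv(bytes_per_group, page_size)
        return group_size * page_size * blocks_needed

    # Example: group_size 3, 16 KiB pages, 100 KiB needed per group
    # -> cdiv rounds 6.25 pages up to 7 -> 3 * 16384 * 7 = 344064 bytes.
    assert max_memory_usage_bytes(3, 16 * 1024, 100 * 1024) == 344064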