Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 0 additions & 41 deletions tests/v1/core/test_kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
KVCacheGroupSpec,
KVCacheSpec,
KVCacheTensor,
MambaSpec,
MLAAttentionSpec,
SlidingWindowSpec,
UniformTypeKVCacheSpecs,
Expand Down Expand Up @@ -158,24 +157,6 @@ def new_chunked_local_attention_spec(
)


def new_mamba_spec(
    block_size=16,
    shapes=((2, 512), (3, 32, 32)),
    dtypes=(torch.float32, torch.float32),
    num_speculative_blocks=2,
    mamba_cache_mode="none",
    page_size_padded=None,
):
    """Build a MambaSpec with test-friendly defaults, forwarding all args."""
    spec_kwargs = dict(
        block_size=block_size,
        shapes=shapes,
        dtypes=dtypes,
        page_size_padded=page_size_padded,
        mamba_cache_mode=mamba_cache_mode,
        num_speculative_blocks=num_speculative_blocks,
    )
    return MambaSpec(**spec_kwargs)


@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_none_hash(monkeypatch, hash_fn):
import vllm.v1.core.kv_cache_utils
Expand Down Expand Up @@ -2029,28 +2010,6 @@ def test_auto_fit_max_model_len():
assert vllm_config.model_config.max_model_len > 0


def test_auto_fit_max_model_len_with_hybrid():
    """Test that auto-fit works with hybrid KV cache specs."""
    # original_max_model_len == -1 is the sentinel that triggers auto-fit.
    model_config = ModelConfig(max_model_len=8192)
    model_config.original_max_model_len = -1
    vllm_config = VllmConfig(model_config=model_config)

    gamma = 2
    mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2  # 16KB per block per layer
    kv_cache_specs = {
        "layer_1": new_mamba_spec(num_speculative_blocks=gamma),
        "layer_2": new_kv_cache_spec(),
    }

    # Budget sized so that auto-fit should land exactly on max_model_len == 1024.
    available_memory = mem_per_block_per_layer * (1024 // 16 + 1 + gamma)
    _kv_cache_configs = get_kv_cache_configs(
        vllm_config, [kv_cache_specs], [available_memory]
    )
    assert vllm_config.model_config.max_model_len == 1024


def test_auto_fit_max_model_len_not_triggered():
"""Test that auto-fit is not triggered when original_max_model_len is not -1."""
model_config = ModelConfig(max_model_len=16)
Expand Down
6 changes: 2 additions & 4 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,10 +1356,8 @@ def _max_memory_usage_bytes_from_groups(
page_size = get_uniform_page_size(
[group.kv_cache_spec for group in kv_cache_groups]
)
blocks_needed = sum(
cdiv(group.kv_cache_spec.max_memory_usage_bytes(vllm_config), page_size)
for group in kv_cache_groups
)
any_spec = kv_cache_groups[0].kv_cache_spec
blocks_needed = cdiv(any_spec.max_memory_usage_bytes(vllm_config), page_size)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The reverted logic for calculating blocks_needed uses any_spec = kv_cache_groups[0].kv_cache_spec and then cdiv(any_spec.max_memory_usage_bytes(vllm_config), page_size). This assumes that the max_memory_usage_bytes is uniform across all kv_cache_spec objects within kv_cache_groups for the "General case" (i.e., when not UniformTypeKVCacheSpecs).

However, different KVCacheSpec types (e.g., FullAttentionSpec vs. SlidingWindowSpec) can have different max_memory_usage_bytes calculations, even if their page_size_bytes are unified. By only considering kv_cache_groups[0].kv_cache_spec, this calculation might underestimate the total blocks needed if subsequent groups have higher memory requirements. This could lead to insufficient memory allocation and runtime failures.

To correctly account for all groups, blocks_needed should be derived from the maximum memory usage among all individual kv_cache_spec objects in the groups.

Suggested change
blocks_needed = cdiv(any_spec.max_memory_usage_bytes(vllm_config), page_size)
blocks_needed = cdiv(max(group.kv_cache_spec.max_memory_usage_bytes(vllm_config) for group in kv_cache_groups), page_size)


return group_size * page_size * blocks_needed

Expand Down
Loading