Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions tests/v1/core/test_kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
KVCacheGroupSpec,
KVCacheSpec,
KVCacheTensor,
MambaSpec,
MLAAttentionSpec,
SlidingWindowSpec,
UniformTypeKVCacheSpecs,
Expand Down Expand Up @@ -157,6 +158,24 @@ def new_chunked_local_attention_spec(
)


def new_mamba_spec(
    block_size=16,
    shapes=((2, 512), (3, 32, 32)),
    dtypes=(torch.float32, torch.float32),
    num_speculative_blocks=2,
    mamba_cache_mode="none",
    page_size_padded=None,
):
    """Build a MambaSpec for tests, pre-filled with reasonable defaults.

    Every argument maps one-to-one onto a MambaSpec field, so individual
    tests only need to override the fields they care about.
    """
    spec_kwargs = dict(
        block_size=block_size,
        shapes=shapes,
        dtypes=dtypes,
        page_size_padded=page_size_padded,
        mamba_cache_mode=mamba_cache_mode,
        num_speculative_blocks=num_speculative_blocks,
    )
    return MambaSpec(**spec_kwargs)


@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_none_hash(monkeypatch, hash_fn):
import vllm.v1.core.kv_cache_utils
Expand Down Expand Up @@ -2010,6 +2029,28 @@ def test_auto_fit_max_model_len():
assert vllm_config.model_config.max_model_len > 0


def test_auto_fit_max_model_len_with_hybrid():
    """Test that auto-fit works with hybrid KV cache specs."""
    # An original_max_model_len of -1 is the sentinel the user passes to
    # request auto-fitting of max_model_len; set it directly here.
    model_config = ModelConfig(max_model_len=8192)
    model_config.original_max_model_len = -1
    vllm_config = VllmConfig(model_config=model_config)

    gamma = 2  # speculative blocks reserved for the mamba layer
    mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2  # 16KB per block per layer

    # Hybrid model: one mamba layer plus one regular attention layer.
    kv_cache_specs = {
        "layer_1": new_mamba_spec(num_speculative_blocks=gamma),
        "layer_2": new_kv_cache_spec(),
    }

    # Budget exactly covers 1024 tokens of attention blocks (block_size 16)
    # plus the mamba state (one block + gamma speculative blocks), so the
    # auto-fit should settle on max_model_len == 1024.
    available_memory = mem_per_block_per_layer * (1024 // 16 + 1 + gamma)
    _kv_cache_configs = get_kv_cache_configs(
        vllm_config, [kv_cache_specs], [available_memory]
    )
    assert vllm_config.model_config.max_model_len == 1024


def test_auto_fit_max_model_len_not_triggered():
"""Test that auto-fit is not triggered when original_max_model_len is not -1."""
model_config = ModelConfig(max_model_len=16)
Expand Down
6 changes: 4 additions & 2 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,8 +1356,10 @@ def _max_memory_usage_bytes_from_groups(
page_size = get_uniform_page_size(
[group.kv_cache_spec for group in kv_cache_groups]
)
any_spec = kv_cache_groups[0].kv_cache_spec
blocks_needed = cdiv(any_spec.max_memory_usage_bytes(vllm_config), page_size)
blocks_needed = sum(
cdiv(group.kv_cache_spec.max_memory_usage_bytes(vllm_config), page_size)
for group in kv_cache_groups
)
Comment on lines +1359 to +1362
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From a quick test I am getting OOM with #37124, but not with #36030. So the highlighted lines seems to be needed @swtb3. This on Blackwell + Qwen3.5-27B-FP8

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've done some digging on the cause of the OOM. I think that getting the proper allocation for Qwen3.5 will have me back to the drawing board. It may not be as simple as I first thought. I would say, if this PR is ready and tested then go for it. I will rebase my PR on top and continue figuring it out. If you've any thoughts on the OOM, let's discuss over on #37124

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will use this PR for the time being, ping me when you have something you want me to test. Thank you!

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@repne new changes pushed, could you test? cheers!


return group_size * page_size * blocks_needed

Expand Down
Loading