Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions tests/v1/core/test_kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
make_block_hash_with_group_id,
)
from vllm.v1.kv_cache_interface import (
CrossAttentionSpec,
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
Expand Down Expand Up @@ -120,6 +121,17 @@ def new_sliding_window_spec(
)


def new_cross_attention_spec(
    block_size=16, num_kv_heads=2, head_size=64, dtype=torch.float32
):
    """Build a CrossAttentionSpec with small, test-friendly default dimensions."""
    spec_kwargs = {
        "block_size": block_size,
        "num_kv_heads": num_kv_heads,
        "head_size": head_size,
        "dtype": dtype,
    }
    return CrossAttentionSpec(**spec_kwargs)


@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_none_hash(monkeypatch, hash_fn):
import vllm.v1.core.kv_cache_utils
Expand Down Expand Up @@ -1406,6 +1418,64 @@ def test_get_kv_cache_config_one_worker():
)


def test_check_enough_kv_cache_memory_respects_num_gpu_blocks_override():
    """If num_gpu_blocks_override is set too small, engine init should fail.

    Guards against configurations where the effective number of KV blocks
    cannot hold even a single request with `model_config.max_model_len`
    tokens.
    """
    # A 32-token max_model_len needs two blocks at the default block_size (16).
    config = VllmConfig(model_config=ModelConfig(max_model_len=32))

    # Single worker with two identical full-attention layers.
    specs = {name: new_kv_cache_spec() for name in ("layer_1", "layer_2")}

    # Restrict the KV cache to a single block via the explicit override.
    config.cache_config.num_gpu_blocks_override = 1

    # Provide ample raw memory so that only the override can cause failure;
    # scale the budget from the per-block total across layers.
    bytes_per_block = sum(s.page_size_bytes for s in specs.values())
    memory_budget = bytes_per_block * 100

    # One block cannot fit a 32-token (two-block) request -> expect an error.
    with pytest.raises(ValueError):
        get_kv_cache_configs(config, [specs], [memory_budget])


def test_override_must_cover_worst_layer_blocks_in_heterogeneous_model():
    """Override must be >= the maximum per-layer required blocks.

    Uses a heterogeneous spec where cross-attention needs more blocks than
    decoder self-attention (due to the default
    max_num_encoder_input_tokens=2048) and checks that an override falling
    between the two per-layer requirements is rejected.
    """
    # Full attention: ceil(1024/16) = 64 blocks; cross: ceil(2048/16) = 128.
    config = VllmConfig(model_config=ModelConfig(max_model_len=1024))

    specs = {
        "decoder_self": new_kv_cache_spec(block_size=16),
        "cross": new_cross_attention_spec(block_size=16),
    }

    # 96 sits strictly between the two requirements (64 < 96 < 128), so the
    # worst layer (cross-attention) is not covered.
    config.cache_config.num_gpu_blocks_override = 96

    # Ample raw memory so only the override cap can trigger the rejection.
    total_page_bytes = sum(s.page_size_bytes for s in specs.values())
    memory_budget = total_page_bytes * 100

    with pytest.raises(ValueError):
        get_kv_cache_configs(config, [specs], [memory_budget])


def test_get_kv_cache_configs_attention_free():
kv_cache_specs: dict[str, KVCacheSpec] = {}
vllm_config = VllmConfig(model_config=ModelConfig(max_model_len=16))
Expand Down
57 changes: 53 additions & 4 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,10 +664,44 @@ def check_enough_kv_cache_memory(
max_model_len = vllm_config.model_config.max_model_len
needed_memory = max_memory_usage_bytes(vllm_config, kv_cache_spec.values())

if needed_memory > available_memory:
# Respect explicit override of number of KV blocks by:
# 1) Validating that the override covers the per-layer required blocks.
# 2) Capping the effective available memory based on the override.
effective_available_memory = available_memory
override_blocks = vllm_config.cache_config.num_gpu_blocks_override
if override_blocks is not None:
# Compute the minimum blocks required among layers for a single request.
# This is ceil(layer_needed_bytes / layer_page_size_bytes).
from vllm.utils import cdiv

per_layer_required_blocks = [
cdiv(spec.max_memory_usage_bytes(vllm_config), spec.page_size_bytes)
for spec in kv_cache_spec.values()
]
max_required_blocks = (
max(per_layer_required_blocks) if per_layer_required_blocks else 0
)

if override_blocks < max_required_blocks:
raise ValueError(
"num_gpu_blocks_override is too small to serve at least one "
"request with the model's max seq len. "
f"Required blocks: {max_required_blocks}, "
f"but got num_gpu_blocks_override={override_blocks}. "
"Increase num_gpu_blocks_override, decrease max_model_len, or "
"increase gpu_memory_utilization."
)

# Cap available memory by the number of blocks allowed via override.
per_block_total = sum(spec.page_size_bytes for spec in kv_cache_spec.values())
effective_available_memory = min(
available_memory, per_block_total * override_blocks
)

if needed_memory > effective_available_memory:
# Estimate the maximum model length that can fit in the available memory
estimated_max_len = estimate_max_model_len(
vllm_config, kv_cache_spec, available_memory
vllm_config, kv_cache_spec, effective_available_memory
)
Comment on lines 703 to 705
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The function estimate_max_model_len modifies the vllm_config.model_config.max_model_len attribute as a side effect of its binary search implementation. While this is not currently causing a bug — this code path always raises an exception immediately afterwards — it is a latent bug: if estimate_max_model_len is ever called from a code path that continues executing, the caller would silently observe a mutated max_model_len.

A function should not have hidden side effects on its arguments. It would be best to refactor estimate_max_model_len to not modify vllm_config, for example by restoring the original value before returning or by working on a copy.

Since the definition of estimate_max_model_len is not in this diff, I'm pointing this out here at the call site. A fix could look like this inside estimate_max_model_len:

def estimate_max_model_len(...):
    original_max_len = vllm_config.model_config.max_model_len
    try:
        # ... existing logic ...
        return result
    finally:
        vllm_config.model_config.max_model_len = original_max_len

estimated_msg = ""
if estimated_max_len > 0:
Expand All @@ -676,11 +710,26 @@ def check_enough_kv_cache_memory(
f"the estimated maximum model length is {estimated_max_len}."
)

# Tailor error message if override is limiting effective capacity.
if (
override_blocks is not None
and effective_available_memory < available_memory
):
extra = (
f"effective available KV cache memory "
f"({effective_available_memory / GiB_bytes:.2f} GiB) "
f"with num_gpu_blocks_override={override_blocks}"
)
else:
extra = (
f"available KV cache memory "
f"({effective_available_memory / GiB_bytes:.2f} GiB)"
)

raise ValueError(
f"To serve at least one request with the models's max seq len "
f"({max_model_len}), ({needed_memory / GiB_bytes:.2f} GiB KV "
f"cache is needed, which is larger than the available KV cache "
f"memory ({available_memory / GiB_bytes:.2f} GiB). "
f"cache is needed, which is larger than the {extra}. "
f"{estimated_msg} "
f"Try increasing `gpu_memory_utilization` or decreasing "
f"`max_model_len` when initializing the engine."
Expand Down