Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 69 additions & 11 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,25 +799,53 @@ def is_kv_cache_spec_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
return True


def _blocks_per_request(
    vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
) -> int:
    """Return number of blocks needed per request at max_model_len.

    Note: the num_layer_per_group factor appears in both numerator and
    denominator and cancels out, so the result is correct regardless of
    whether page_size_bytes already includes all layers (as in
    UniformTypeKVCacheSpecs) or is per-layer.
    """
    groups = kv_cache_config.kv_cache_groups
    # Widest group sets the per-block layer multiplier.
    num_layer_per_group = max(len(g.layer_names) for g in groups)
    page_size = groups[0].kv_cache_spec.page_size_bytes
    # Total bytes one full-length request needs, summed across all groups
    # (hybrid models consume blocks per group, so the demand is the sum).
    per_request_bytes = num_layer_per_group * max_memory_usage_bytes(
        vllm_config,
        (g.kv_cache_spec for g in groups),
    )
    # The layer factor cancels between numerator and denominator.
    return cdiv(per_request_bytes, page_size * num_layer_per_group)
Comment on lines +802 to +821
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The calculation for _blocks_per_request seems incorrect for models with multiple KV cache groups (e.g., hybrid models). It appears to return num_groups * cdiv(max_model_len, block_size) instead of just cdiv(max_model_len, block_size).

This will cause get_max_concurrency_for_kv_cache_config to underestimate the maximum concurrency by a factor of num_groups, and the new logging for needed_kv_bytes will overestimate the required memory by the same factor.

The number of blocks required for a sequence should be independent of the number of KV cache groups, as blocks from the pool are allocated per sequence, and each block from the pool serves all groups (via memory sharing across layers).

A simpler and more correct implementation would calculate the blocks needed per layer for a single sequence, assuming all groups share the same block size.

Suggested change
def _blocks_per_request(
vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
) -> float:
"""
Get the maximum concurrency for the given KV cache configuration.
) -> int:
"""Return number of blocks needed per request at max_model_len.
Note: the num_layer_per_group factor appears in both numerator and
denominator and cancels out, so the result is correct regardless of
whether page_size_bytes already includes all layers (as in
UniformTypeKVCacheSpecs) or is per-layer.
"""
num_layer_per_group = max(
len(group.layer_names) for group in kv_cache_config.kv_cache_groups
)
page_size = kv_cache_config.kv_cache_groups[0].kv_cache_spec.page_size_bytes
max_memory_usage_per_request = num_layer_per_group * max_memory_usage_bytes(
vllm_config, (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups)
vllm_config,
(group.kv_cache_spec for group in kv_cache_config.kv_cache_groups),
)
memory_per_block = (
kv_cache_config.kv_cache_groups[0].kv_cache_spec.page_size_bytes
* num_layer_per_group
memory_per_block = page_size * num_layer_per_group
return cdiv(max_memory_usage_per_request, memory_per_block)
def _blocks_per_request(
vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
) -> int:
"""Return number of blocks needed per request at max_model_len.
Note: This assumes that all KV cache groups have the same block size.
"""
# All groups must have same block size. We take the spec from the first
# group as representative for block size and max memory usage calculation
# per layer.
spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
max_memory_per_layer = spec.max_memory_usage_bytes(vllm_config)
bytes_per_block_per_layer = spec.page_size_bytes
return cdiv(max_memory_per_layer, bytes_per_block_per_layer)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I double-checked this path, and the current _blocks_per_request math is intentional for hybrid KV configs.
In vLLM, blocks are consumed per KV cache group for a request, so for hybrid models the request block demand is the sum across groups (e.g., full-attn blocks + sliding-window blocks), not just cdiv(max_model_len, block_size) from one representative group.
That is why _blocks_per_request computes:

  • numerator: num_layer_per_group * sum(group.max_memory_usage_bytes(...))
  • denominator: num_layer_per_group * page_size
    which simplifies to cdiv(sum_group_memory, page_size) (the num_layer_per_group factor cancels). This preserves existing concurrency semantics.

There is an existing test that reflects this behavior:
tests/v1/core/test_kv_cache_utils.py:1405 (kv_cache_config_hybrid_model) expects concurrency 3 for num_blocks=(1024 + 129) * 3, i.e., blocks/request is 1024 + 129, not 1024.

Also, the recent fix in this PR separates byte accounting for logging (_kv_cache_bytes_per_block) so UniformType no longer overcounts GiB, while keeping concurrency math unchanged.



def _kv_cache_bytes_per_block(kv_cache_config: KVCacheConfig) -> int:
    """Return the actual memory footprint of one block across all layers.

    For UniformTypeKVCacheSpecs, page_size_bytes already sums across all
    layers in the group. For other spec types, page_size_bytes is per-layer
    and must be multiplied by the layer count.
    """
    first_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
    if isinstance(first_spec, UniformTypeKVCacheSpecs):
        # page_size_bytes is already aggregated over every layer.
        return first_spec.page_size_bytes
    # Per-layer page size: scale by the largest group's layer count.
    widest_group = max(
        len(group.layer_names) for group in kv_cache_config.kv_cache_groups
    )
    return first_spec.page_size_bytes * widest_group


def get_max_concurrency_for_kv_cache_config(
    vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
) -> float:
    """
    Get the maximum concurrency for the given KV cache configuration.
    """
    # Concurrency = pool size divided by one request's block demand.
    blocks_needed = _blocks_per_request(vllm_config, kv_cache_config)
    return kv_cache_config.num_blocks / blocks_needed


def may_override_num_blocks(vllm_config: VllmConfig, num_blocks: int) -> int:
Expand Down Expand Up @@ -1328,6 +1356,36 @@ def _report_kv_cache_config(
scope="local",
)

# Log KV cache memory in GiB and workload capacity.
blocks_per_req = _blocks_per_request(vllm_config, kv_cache_config)
bytes_per_block = _kv_cache_bytes_per_block(kv_cache_config)
allocated_kv_bytes = kv_cache_config.num_blocks * bytes_per_block
logger.info_once(
"GPU KV cache memory: %s GiB (%d blocks)",
format_gib(allocated_kv_bytes),
kv_cache_config.num_blocks,
scope="local",
)

max_num_seqs = vllm_config.scheduler_config.max_num_seqs
needed_kv_bytes = blocks_per_req * max_num_seqs * bytes_per_block
logger.info_once(
"KV cache for %d seqs x %s tokens: %s GiB (allocated: %s GiB)",
max_num_seqs,
max_model_len_str,
format_gib(needed_kv_bytes),
format_gib(allocated_kv_bytes),
scope="local",
)
if max_concurrency < max_num_seqs:
logger.warning(
"KV cache can hold %d full-length sequences but "
"max_num_seqs is %d. "
"Sequences will queue when KV cache is full.",
int(max_concurrency),
max_num_seqs,
)


def _max_memory_usage_bytes_from_groups(
vllm_config: VllmConfig,
Expand Down
Loading