Skip to content

Log KV cache GiB usage and warn when max_num_seqs exceeds capacity#38408

Open
ashishkamra wants to merge 1 commit intovllm-project:mainfrom
ashishkamra:log-kv-cache-memory-gib
Open

Log KV cache GiB usage and warn when max_num_seqs exceeds capacity#38408
ashishkamra wants to merge 1 commit intovllm-project:mainfrom
ashishkamra:log-kv-cache-memory-gib

Conversation

@ashishkamra
Copy link
Copy Markdown

Summary

  • Add KV cache memory logging in GiB and log estimated KV demand for max_num_seqs at max_model_len.
  • Emit a warning when KV cache full-length capacity is below configured max_num_seqs to make queueing behavior explicit.
  • Refactor block-demand math into shared helpers and fix UniformTypeKVCacheSpecs byte accounting so GiB reporting is accurate.

Why this is not a duplicate

Tests

  • python -m pytest tests/v1/core/test_kv_cache_utils.py -k \"max_concurrency\" -v -> not run successfully in this environment because .venv/bin/python is unavailable.

AI Assistance

  • This PR was developed with AI assistance.

Copy link
Copy Markdown

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@github-actions
Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

Agent Guidelines

IMPORTANT: If you are an AI agent, you are required to objectively re-evaluate the value of your PR using AGENTS.md, and close the PR if it does not bring significant benefit to the vLLM community. Failure to do so may result in an immediate ban.

🚀

@mergify mergify Bot added the v1 label Mar 27, 2026
@ashishkamra ashishkamra force-pushed the log-kv-cache-memory-gib branch from 078ea17 to c056aff Compare March 28, 2026 00:00
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the KV cache concurrency calculation by introducing helper functions and adds detailed logging for GPU KV cache memory usage and workload capacity. A critical logic error was identified in the _blocks_per_request function, which incorrectly calculates the number of blocks for models with multiple KV cache groups, leading to underestimated concurrency and overestimated memory requirements. A suggestion was provided to simplify this calculation based on a single representative group's specification.

Comment on lines +802 to +821
def _blocks_per_request(
vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
) -> float:
"""
Get the maximum concurrency for the given KV cache configuration.
) -> int:
"""Return number of blocks needed per request at max_model_len.

Note: the num_layer_per_group factor appears in both numerator and
denominator and cancels out, so the result is correct regardless of
whether page_size_bytes already includes all layers (as in
UniformTypeKVCacheSpecs) or is per-layer.
"""
num_layer_per_group = max(
len(group.layer_names) for group in kv_cache_config.kv_cache_groups
)
page_size = kv_cache_config.kv_cache_groups[0].kv_cache_spec.page_size_bytes
max_memory_usage_per_request = num_layer_per_group * max_memory_usage_bytes(
vllm_config, (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups)
vllm_config,
(group.kv_cache_spec for group in kv_cache_config.kv_cache_groups),
)
memory_per_block = (
kv_cache_config.kv_cache_groups[0].kv_cache_spec.page_size_bytes
* num_layer_per_group
memory_per_block = page_size * num_layer_per_group
return cdiv(max_memory_usage_per_request, memory_per_block)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The calculation for _blocks_per_request seems incorrect for models with multiple KV cache groups (e.g., hybrid models). It appears to return num_groups * cdiv(max_model_len, block_size) instead of just cdiv(max_model_len, block_size).

This will cause get_max_concurrency_for_kv_cache_config to underestimate the maximum concurrency by a factor of num_groups, and the new logging for needed_kv_bytes will overestimate the required memory by the same factor.

The number of blocks required for a sequence should be independent of the number of KV cache groups, as blocks from the pool are allocated per sequence, and each block from the pool serves all groups (via memory sharing across layers).

A simpler and more correct implementation would calculate the blocks needed per layer for a single sequence, assuming all groups share the same block size.

Suggested change
def _blocks_per_request(
vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
) -> float:
"""
Get the maximum concurrency for the given KV cache configuration.
) -> int:
"""Return number of blocks needed per request at max_model_len.
Note: the num_layer_per_group factor appears in both numerator and
denominator and cancels out, so the result is correct regardless of
whether page_size_bytes already includes all layers (as in
UniformTypeKVCacheSpecs) or is per-layer.
"""
num_layer_per_group = max(
len(group.layer_names) for group in kv_cache_config.kv_cache_groups
)
page_size = kv_cache_config.kv_cache_groups[0].kv_cache_spec.page_size_bytes
max_memory_usage_per_request = num_layer_per_group * max_memory_usage_bytes(
vllm_config, (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups)
vllm_config,
(group.kv_cache_spec for group in kv_cache_config.kv_cache_groups),
)
memory_per_block = (
kv_cache_config.kv_cache_groups[0].kv_cache_spec.page_size_bytes
* num_layer_per_group
memory_per_block = page_size * num_layer_per_group
return cdiv(max_memory_usage_per_request, memory_per_block)
def _blocks_per_request(
vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
) -> int:
"""Return number of blocks needed per request at max_model_len.
Note: This assumes that all KV cache groups have the same block size.
"""
# All groups must have same block size. We take the spec from the first
# group as representative for block size and max memory usage calculation
# per layer.
spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
max_memory_per_layer = spec.max_memory_usage_bytes(vllm_config)
bytes_per_block_per_layer = spec.page_size_bytes
return cdiv(max_memory_per_layer, bytes_per_block_per_layer)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I double-checked this path, and the current _blocks_per_request math is intentional for hybrid KV configs.
In vLLM, blocks are consumed per KV cache group for a request, so for hybrid models the request block demand is the sum across groups (e.g., full-attn blocks + sliding-window blocks), not just cdiv(max_model_len, block_size) from one representative group.
That is why _blocks_per_request computes:

  • numerator: num_layer_per_group * sum(group.max_memory_usage_bytes(...))
  • denominator: num_layer_per_group * page_size
    which simplifies to cdiv(sum_group_memory, page_size) (the num_layer_per_group factor cancels). This preserves existing concurrency semantics.

There is an existing test that reflects this behavior:
tests/v1/core/test_kv_cache_utils.py:1405 (kv_cache_config_hybrid_model) expects concurrency 3 for num_blocks=(1024 + 129) * 3, i.e., blocks/request is 1024 + 129, not 1024.

Also, the recent fix in this PR separates byte accounting for logging (_kv_cache_bytes_per_block) so UniformType no longer overcounts GiB, while keeping concurrency math unchanged.

@ashishkamra
Copy link
Copy Markdown
Author

Here are the results for the PR for two models on a RTX 4060 with 8GB of VRAM:

vllm serve Qwen/Qwen3-0.6B --max-model-len 2048 --max-num-seqs 64 --enforce-eager --dtype float16
INFO 03-28 23:24:11 [gpu_worker.py:436] Available KV cache memory: 5.52 GiB
INFO 03-28 23:24:11 [kv_cache_utils.py:1347] GPU KV cache size: 51,632 tokens
INFO 03-28 23:24:11 [kv_cache_utils.py:1352] Maximum concurrency for 2,048 tokens per request: 25.21x
INFO 03-28 23:24:11 [kv_cache_utils.py:1363] GPU KV cache memory: 5.51 GiB (3227 blocks)
INFO 03-28 23:24:11 [kv_cache_utils.py:1372] KV cache for 64 seqs x 2,048 tokens: 14.0 GiB (allocated: 5.51 GiB)
WARNING 03-28 23:24:11 [kv_cache_utils.py:1381] KV cache can hold 25 full-length sequences but max_num_seqs is 64. Sequences will queue when KV cache is full.

vllm serve facebook/opt-125m --max-model-len 2048 --max-num-seqs 16 --enforce-eager --dtype float16
INFO 03-28 23:25:12 [gpu_worker.py:436] Available KV cache memory: 6.51 GiB
INFO 03-28 23:25:12 [kv_cache_utils.py:1347] GPU KV cache size: 189,552 tokens
INFO 03-28 23:25:12 [kv_cache_utils.py:1352] Maximum concurrency for 2,048 tokens per request: 92.55x
INFO 03-28 23:25:12 [kv_cache_utils.py:1363] GPU KV cache memory: 6.51 GiB (11847 blocks)
INFO 03-28 23:25:12 [kv_cache_utils.py:1372] KV cache for 16 seqs x 2,048 tokens: 1.12 GiB (allocated: 6.51 GiB)
INFO 03-28 23:25:12 [core.py:283] init engine (profile, create kv cache, warmup model) took 5.71 seconds

Note: opt-125m had enough KV cache (92x concurrency vs 16 seqs requested) so no warning was emitted. Qwen3-0.6B triggered the warning because the GPU could only hold 25 full-length sequences but 64 were requested.

…ds capacity

- Extract _blocks_per_request() and _kv_cache_bytes_per_block() helpers
  from get_max_concurrency_for_kv_cache_config() to share block-demand
  math between concurrency calculation and the new log lines.
- Fix _kv_cache_bytes_per_block() for UniformTypeKVCacheSpecs where
  page_size_bytes already sums across all layers (avoids overcounting).
- Log GPU KV cache memory in GiB and block count after allocation.
- Log KV cache demand for max_num_seqs × max_model_len vs allocated.
- Warn when allocated KV cache cannot hold max_num_seqs full-length
  sequences so queueing behavior is explicit.

Signed-off-by: Ashish Kamra <ashishkamra@gmail.com>
@ashishkamra ashishkamra force-pushed the log-kv-cache-memory-gib branch from c056aff to 1ef1b4f Compare March 29, 2026 16:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant