
[BugFix] Correct max memory usage for multiple KV-cache groups#36030

Merged
heheda12345 merged 4 commits into vllm-project:main from peakcrosser7:fix/max_mem_usage on Mar 17, 2026

Conversation

@peakcrosser7 (Contributor) commented Mar 4, 2026

Purpose

This PR fixes a calculation error in _max_memory_usage_bytes_from_groups() that led to underestimating total memory usage in multi-group cases.

Root Cause

The original implementation calculated blocks_needed based only on the first KV-cache group (index 0). In models with multiple KV-cache groups (e.g., hybrid Mamba models), the block usage of the other groups was neglected. This resulted in a lower-than-actual total memory estimate, which in turn caused the engine to over-estimate max_model_len.

For example, in test_auto_fit_max_model_len_with_hybrid(), the model consists of both Mamba and FullAttn groups. Since the original logic only considered the Mamba group (index 0), which occupies a fixed number of blocks when prefix-caching is disabled, it ignored the memory usage of the FullAttn group and led to an excessively large max_model_len.
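To make the effect concrete, here is a small, purely illustrative Python sketch. All numbers (page size, memory budget, per-group byte counts, and the growth rate of the attention group) are made-up assumptions, not the real values used by _max_memory_usage_bytes_from_groups():

def cdiv(a: int, b: int) -> int:
    return -(-a // b)  # ceiling division

PAGE = 2 * 1024 * 1024    # assumed 2 MiB page size
BUDGET = 8 * 1024 ** 3    # assumed 8 GiB KV-cache budget

def needed_bytes(max_len: int, only_first_group: bool) -> int:
    mamba_bytes = 256 * 1024 * 1024      # fixed-size state, independent of max_len
    attn_bytes = max_len * 1024 * 1024   # grows with max_len (toy rate)
    groups = [mamba_bytes] if only_first_group else [mamba_bytes, attn_bytes]
    return sum(cdiv(b, PAGE) for b in groups) * PAGE

print(needed_bytes(8192, only_first_group=True) <= BUDGET)   # True: counting only group 0, 8192 appears to fit
print(needed_bytes(8192, only_first_group=False) <= BUDGET)  # False: counting all groups, 8192 does not fit
print(needed_bytes(1024, only_first_group=False) <= BUDGET)  # True: a much smaller max_model_len fits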

Solution

Updated the logic to sum the blocks from all KV-Cache groups, ensuring the total memory usage accurately reflects the requirements of the entire model.
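A minimal sketch of the shape of the change, using plain integers and a standalone cdiv helper instead of vLLM's group/spec objects (the actual lines are shown in the review snippet further down):

def cdiv(a: int, b: int) -> int:
    return -(-a // b)  # ceiling division

def total_blocks_needed(group_bytes: list[int], page_size: int) -> int:
    # Before (simplified): only group 0 was counted, i.e. cdiv(group_bytes[0], page_size).
    # After: every KV-cache group contributes its own page-aligned block count.
    return sum(cdiv(b, page_size) for b in group_bytes)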

Test Plan

Added a new unit test, test_auto_fit_max_model_len_with_hybrid(), which can be run with:

pytest tests/v1/core/test_kv_cache_utils.py::test_auto_fit_max_model_len_with_hybrid

Test Result

Before fix:

_____________________________________________________ ERROR collecting tests/v1/core/test_kv_cache_utils.py ______________________________________________________
tests/v1/core/test_kv_cache_utils.py:2127: in <module>
    test_auto_fit_max_model_len_with_hybrid()
tests/v1/core/test_kv_cache_utils.py:2038: in test_auto_fit_max_model_len_with_hybrid
    assert vllm_config.model_config.max_model_len == 1024
E   AssertionError: assert 8192 == 1024
E    +  where 8192 = ModelConfig(model='Qwen/Qwen3-0.6B', model_weights='', runner='auto', convert='auto', tokenizer='Qwen/Qwen3-0.6B', tok...ide_attention_dtype=None, logits_processors=None, io_processor_plugin=None, pooler_config=None, multimodal_config=None).max_model_len
E    +    where ModelConfig(model='Qwen/Qwen3-0.6B', model_weights='', runner='auto', convert='auto', tokenizer='Qwen/Qwen3-0.6B', tok...ide_attention_dtype=None, logits_processors=None, io_processor_plugin=None, pooler_config=None, multimodal_config=None) = VllmConfig(model_config=ModelConfig(model='Qwen/Qwen3-0.6B', model_weights='', runner='auto', convert='auto', tokenize...6572044341470', optimization_level=<OptimizationLevel.O2: 2>, performance_mode='balanced', weight_transfer_config=None).model_config

After fix:

root@hhy_develop_vllm_opsrc:~/huanghy/vllm_opsrc# pytest tests/v1/core/test_kv_cache_utils.py::test_auto_fit_max_model_len_with_hybrid
====================================================================== test session starts =======================================================================
platform linux -- Python 3.11.13, pytest-8.4.2, pluggy-1.6.0
rootdir: /root/huanghy/vllm_opsrc
configfile: pyproject.toml
plugins: hypothesis-6.137.1, anyio-4.11.0
collected 1 item                                                                                                                                                 

tests/v1/core/test_kv_cache_utils.py .                                                                                                                     [100%]

======================================================================== warnings summary ========================================================================
<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

../../../opt/conda/lib/python3.11/site-packages/torch/jit/_script.py:362: 14 warnings
  /opt/conda/lib/python3.11/site-packages/torch/jit/_script.py:362: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================ 1 passed, 16 warnings in 13.92s =================================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute


Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
mergify bot added the v1 and bug labels Mar 4, 2026
@gemini-code-assist bot left a comment

Code Review

This pull request addresses a bug in the maximum memory usage calculation, which is particularly relevant for hybrid models. The change in vllm/v1/core/kv_cache_utils.py correctly sums the memory requirements across all KV cache groups, whereas the previous implementation only considered the first group. This ensures accurate memory estimation, for instance when auto-fitting max_model_len. The addition of a new test case for hybrid KV cache specs in tests/v1/core/test_kv_cache_utils.py is a good measure to prevent regressions.

peakcrosser7 marked this pull request as ready for review March 5, 2026 16:50
peakcrosser7 changed the title from "[BugFix] fix max memory usage" to "[BugFix] Correct max memory usage for multiple KV-cache groups" Mar 5, 2026
@heheda12345 (Collaborator) left a comment

LGTM!

heheda12345 enabled auto-merge (squash) March 10, 2026 05:44
github-actions bot added the ready label Mar 10, 2026
@repne commented Mar 16, 2026

Hi @peakcrosser7, thank you for your work on this. I've been testing this PR without issues for the past week or so.
I was wondering, is this PR superseded by #37124?

@swtb3 commented Mar 16, 2026


Hello, I'm the author of #37124. I also fixed _max_memory_usage_bytes_from_groups as part of a broader set of changes to hybrid Mamba/attention KV cache handling (reporting, allocation, and concurrency estimation). There is overlap on that one function. Let me know how we should proceed :)

(My apologies, I did not see this PR; the issue that brought me here was memory over-estimation for the Qwen 3.5 series.)

@peakcrosser7 (Contributor, Author) commented

Hi @repne, thanks for pointing that out! It seems there is indeed some overlap between this PR and #37124.

Hi @swtb3, great work on #37124! It looks like a more comprehensive solution to the problem. I'm not entirely sure about the best way to move forward with my PR, since the auto-merge wasn’t triggered as expected earlier.

@heheda12345, what are your thoughts on this?

Comment on lines +1359 to +1362
blocks_needed = sum(
    cdiv(group.kv_cache_spec.max_memory_usage_bytes(vllm_config), page_size)
    for group in kv_cache_groups
)

From a quick test I am getting OOM with #37124, but not with #36030, so the highlighted lines seem to be needed @swtb3. This is on Blackwell + Qwen3.5-27B-FP8.


I've done some digging on the cause of the OOM. I think getting the proper allocation for Qwen3.5 will have me back at the drawing board; it may not be as simple as I first thought. I would say, if this PR is ready and tested, then go for it. I will rebase my PR on top and continue figuring it out. If you've any thoughts on the OOM, let's discuss over on #37124.


Will use this PR for the time being, ping me when you have something you want me to test. Thank you!


@repne new changes pushed, could you test? cheers!

@tdoublep (Member) left a comment

LGTM

heheda12345 merged commit 45f526d into vllm-project:main Mar 17, 2026
47 checks passed
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Mar 17, 2026
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Mar 19, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026

Labels

bug, ready, v1
